Code Example #1
class EventSource(RequestHandler):
    def initialize(self, stream):
        #assert isinstance(stream, Stream)
        self.stream = stream
        self.messages = Queue()
        self.finished = False
        self.set_header('content-type', 'text/event-stream')
        self.set_header('cache-control', 'no-cache')
        self.store = self.stream.sink(self.messages.put)

    @gen.coroutine
    def publish(self, message):
        """Pushes data to a listener."""
        try:
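            # `>>` appears to be the source project's pipe operator:
            # it pipes message through to_str, and the tuple below into log.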
            self.write(message >> to_str)
            yield self.flush()
        except StreamClosedError:
            self.finished = True
            (self.request.remote_ip, StreamClosedError) >> log

    @gen.coroutine
    def get(self, *args, **kwargs):
        try:
            while not self.finished:
                message = yield self.messages.get()
                yield self.publish(message)
        except Exception:
            pass
        finally:
            self.store.destroy()
            self.messages.empty()  # note: Queue.empty() only reports state; it does not clear the queue
            self.finish()
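
A minimal sketch of wiring this handler into an application, assuming a streamz-style Stream whose sink(callback) pushes every emitted element into the callback (the import path and port are assumptions):

from tornado.ioloop import IOLoop
from tornado.web import Application
from streamz import Stream  # assumption: the Stream type the handler expects

stream = Stream()
app = Application([
    # initialize() receives the keyword arguments passed in this dict
    (r"/events", EventSource, dict(stream=stream)),
])
app.listen(8888)
# Every element emitted into the stream is queued and published
# to each connected event-source listener.
IOLoop.current().add_callback(stream.emit, "data: hello\n\n")
IOLoop.current().start()
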
Code Example #2
File: dask.py  Project: kszucs/streams
class gather(Stream):
    def __init__(self, child, limit=10, client=None):
        self.client = client or default_client()
        self.queue = Queue(maxsize=limit)
        self.condition = Condition()

        Stream.__init__(self, child)

        self.client.loop.add_callback(self.cb)

    def update(self, x, who=None):
        return self.queue.put(x)

    @gen.coroutine
    def cb(self):
        while True:
            x = yield self.queue.get()
            L = [x]
            while not self.queue.empty():
                L.append(self.queue.get_nowait())
            results = yield self.client._gather(L)
            for x in results:
                yield self.emit(x)
            if self.queue.empty():
                self.condition.notify_all()

    @gen.coroutine
    def flush(self):
        while not self.queue.empty():
            yield self.condition.wait()
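
The core idiom in cb() above is worth isolating: block on get() for at least one item, then drain whatever else has accumulated with get_nowait(), so a single round-trip to the client gathers a whole batch. A self-contained sketch of that idiom (all names here are illustrative):

from tornado import gen
from tornado.queues import Queue, QueueEmpty

queue = Queue()

@gen.coroutine
def drain_batch():
    first = yield queue.get()   # block until at least one item arrives
    batch = [first]
    while True:                 # then opportunistically drain the rest
        try:
            batch.append(queue.get_nowait())
        except QueueEmpty:
            break
    raise gen.Return(batch)
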
Code Example #3
class ConnectionPool(object):
    def __init__(self, servers, maxsize=15, minsize=1, loop=None, debug=0):
        loop = loop if loop is not None else tornado.ioloop.IOLoop.instance()
        if debug:
            logging.basicConfig(
                level=logging.DEBUG,
                format="'%(levelname)s %(asctime)s"
                " %(module)s:%(lineno)d %(process)d %(thread)d %(message)s'")
        self._loop = loop
        self._servers = servers
        self._minsize = minsize
        self._debug = debug
        self._in_use = set()
        self._pool = Queue(maxsize)

    @gen.coroutine
    def clear(self):
        """Clear pool connections."""
        while not self._pool.empty():
            conn = yield self._pool.get()
            conn.close_socket()

    def size(self):
        return len(self._in_use) + self._pool.qsize()

    @gen.coroutine
    def acquire(self):
        """Acquire connection from the pool, or spawn new one
        if pool maxsize permits.

        :return: ``Connection`` (reader, writer)
        """
        while self.size() < self._minsize:
            _conn = yield self._create_new_conn()
            yield self._pool.put(_conn)

        conn = None
        while not conn:
            if not self._pool.empty():
                conn = yield self._pool.get()

            if conn is None:
                conn = yield self._create_new_conn()

        self._in_use.add(conn)
        raise gen.Return(conn)

    @gen.coroutine
    def _create_new_conn(self):
        conn = yield Connection.get_conn(self._servers, self._debug)
        raise gen.Return(conn)

    def release(self, conn):
        self._in_use.remove(conn)
        try:
            self._pool.put_nowait(conn)
        except (QueueEmpty, QueueFull):
            conn.close_socket()
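
A sketch of the acquire/release discipline this pool expects from callers, with a try/finally so the connection is always handed back (the server address is a placeholder):

from tornado import gen

pool = ConnectionPool(servers=["127.0.0.1:11211"], maxsize=15, minsize=1)

@gen.coroutine
def query():
    conn = yield pool.acquire()
    try:
        pass  # use the connection here
    finally:
        # returns the connection to the pool, or closes its socket if the pool is full
        pool.release(conn)
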
Code Example #4
File: batched.py  Project: broxtronix/distributed
class BatchedStream(object):
    """ Mostly obsolete, see BatchedSend """

    def __init__(self, stream, interval):
        self.stream = stream
        self.interval = interval / 1000.0
        self.last_transmission = default_timer()
        self.send_q = Queue()
        self.recv_q = Queue()
        self._background_send_coroutine = self._background_send()
        self._background_recv_coroutine = self._background_recv()
        self._broken = None

        self.pc = PeriodicCallback(lambda: None, 100)
        self.pc.start()

    @gen.coroutine
    def _background_send(self):
        with log_errors():
            while True:
                msg = yield self.send_q.get()
                if msg == "close":
                    break
                msgs = [msg]
                now = default_timer()
                wait_time = self.last_transmission + self.interval - now
                if wait_time > 0:
                    yield gen.sleep(wait_time)
                while not self.send_q.empty():
                    msgs.append(self.send_q.get_nowait())

                try:
                    yield write(self.stream, msgs)
                except StreamClosedError:
                    self.recv_q.put_nowait("close")
                    self._broken = True
                    break

                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for _ in msgs:
                    self.send_q.task_done()

    @gen.coroutine
    def _background_recv(self):
        with log_errors():
            while True:
                try:
                    msgs = yield read(self.stream)
                except StreamClosedError:
                    self.recv_q.put_nowait("close")
                    self.send_q.put_nowait("close")
                    self._broken = True
                    break
                assert isinstance(msgs, list)
                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for msg in msgs:
                    self.recv_q.put_nowait(msg)

    @gen.coroutine
    def flush(self):
        yield self.send_q.join()

    @gen.coroutine
    def send(self, msg):
        if self._broken:
            raise StreamClosedError("Batch Stream is Closed")
        else:
            self.send_q.put_nowait(msg)

    @gen.coroutine
    def recv(self):
        result = yield self.recv_q.get()
        if result == "close":
            raise StreamClosedError("Batched Stream is Closed")
        else:
            raise gen.Return(result)

    @gen.coroutine
    def close(self):
        yield self.flush()
        raise gen.Return(self.stream.close())

    def closed(self):
        return self.stream.closed()
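
A usage sketch, assuming stream is an already-connected Tornado IOStream and that write/read are the framed helpers used by the class above:

from tornado import gen

@gen.coroutine
def ping(stream):
    batched = BatchedStream(stream, interval=10)  # batch up to 10 ms of messages
    yield batched.send({"op": "ping"})            # queued; sent by _background_send
    reply = yield batched.recv()
    yield batched.close()
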
Code Example #5
File: batched.py  Project: wanjinchang/distributed
class BatchedStream(object):
    """ Mostly obsolete, see BatchedSend """
    def __init__(self, stream, interval):
        self.stream = stream
        self.interval = interval / 1000.
        self.last_transmission = default_timer()
        self.send_q = Queue()
        self.recv_q = Queue()
        self._background_send_coroutine = self._background_send()
        self._background_recv_coroutine = self._background_recv()
        self._broken = None

        self.pc = PeriodicCallback(lambda: None, 100)
        self.pc.start()

    @gen.coroutine
    def _background_send(self):
        with log_errors():
            while True:
                msg = yield self.send_q.get()
                if msg == 'close':
                    break
                msgs = [msg]
                now = default_timer()
                wait_time = self.last_transmission + self.interval - now
                if wait_time > 0:
                    yield gen.sleep(wait_time)
                while not self.send_q.empty():
                    msgs.append(self.send_q.get_nowait())

                try:
                    yield write(self.stream, msgs)
                except StreamClosedError:
                    self.recv_q.put_nowait('close')
                    self._broken = True
                    break

                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for _ in msgs:
                    self.send_q.task_done()

    @gen.coroutine
    def _background_recv(self):
        with log_errors():
            while True:
                try:
                    msgs = yield read(self.stream)
                except StreamClosedError:
                    self.recv_q.put_nowait('close')
                    self.send_q.put_nowait('close')
                    self._broken = True
                    break
                assert isinstance(msgs, list)
                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for msg in msgs:
                    self.recv_q.put_nowait(msg)

    @gen.coroutine
    def flush(self):
        yield self.send_q.join()

    @gen.coroutine
    def send(self, msg):
        if self._broken:
            raise StreamClosedError('Batch Stream is Closed')
        else:
            self.send_q.put_nowait(msg)

    @gen.coroutine
    def recv(self):
        result = yield self.recv_q.get()
        if result == 'close':
            raise StreamClosedError('Batched Stream is Closed')
        else:
            raise gen.Return(result)

    @gen.coroutine
    def close(self):
        yield self.flush()
        raise gen.Return(self.stream.close())

    def closed(self):
        return self.stream.closed()
Code Example #6
class ChannelConfiguration:
    _USER_CLOSE_CODE = 0
    _NORMAL_CLOSE_CODE = 200
    _NO_ROUTE_CODE = 312

    logger = logger
    connection: AsyncConnection

    def __init__(self, connection: AsyncConnection, io_loop, exchange=None, exchange_type=None, queue=None,
                 routing_key=None, durable=False, auto_delete=False, prefetch_count=None):

        self._channel = None
        self.connection = connection
        self._io_loop = io_loop
        self._channel_queue = Queue(maxsize=1)
        self._queue = queue if queue is not None else ""
        self._exchange = exchange
        if exchange_type is None:
            exchange_type = 'topic'
        self._exchange_type = exchange_type
        self._routing_key = routing_key
        self._durable = durable
        self._auto_delete = auto_delete
        if prefetch_count is None:
            prefetch_count = 1
        self._prefetch_count = prefetch_count
        self._should_consume = False
        self._consume_params = dict()

    @gen.coroutine
    def consume(self, on_message_callback, handler=None, no_ack=False):
        self.logger.info(f"[start consuming] routing key: {self._routing_key}; queue name: {self._queue}")
        channel = yield self._get_channel()

        self._should_consume = True
        self._consume_params = [on_message_callback, handler, no_ack]
        if handler is not None:
            channel.basic_consume(
                queue=self._queue,
                auto_ack=no_ack,
                on_message_callback=functools.partial(
                    on_message_callback,
                    handler=handler
                )
            )
        else:
            channel.basic_consume(queue=self._queue, on_message_callback=on_message_callback, auto_ack=no_ack)

    @gen.coroutine
    def publish(self, body, mandatory=None, properties=None, reply_to=None):
        channel = yield self._get_channel()
        if reply_to is not None:
            exchange = ""
            routing_key = reply_to
        else:
            exchange = self._exchange
            routing_key = self._routing_key

        self.logger.info(f"Publishing message. exchange: {exchange}; routing_key: {routing_key}")
        channel.basic_publish(exchange=exchange, routing_key=routing_key, body=body,
                              mandatory=mandatory, properties=properties)

    @gen.coroutine
    def _get_channel(self):
        if self._channel_queue.empty():
            yield self._create_channel()

        channel = yield self._top()
        return channel

    @gen.coroutine
    def _top(self):
        channel = yield self._channel_queue.get()
        self._channel_queue.put(channel)
        return channel

    def _remove_channel_from_queue(self):
        try:
            self._channel_queue.get_nowait()
        except QueueEmpty:
            pass

    @gen.coroutine
    def _create_channel(self):
        self.logger.info("creating channel")
        connection = yield self.connection.get_connection()

        def on_channel_flow(*args, **kwargs):
            pass

        def on_channel_cancel(frame):
            self.logger.error("Channel was canceled")
            if not self._channel_queue.empty():
                channel = self._channel
                if channel and not (channel.is_closed or channel.is_closing):
                    channel.close()

        def on_channel_closed(channel, reason):
            reply_code, reply_txt = reason.args
            self.logger.info(f'Channel {channel} was closed: {reason}')

            if reply_code not in [self._NORMAL_CLOSE_CODE, self._USER_CLOSE_CODE]:
                self.logger.error(f"Channel closed. reply code: {reply_code}; reply text: {reply_txt}. "
                                  f"System will exist")
                if connection and not (connection.is_closed or connection.is_closing):
                    connection.close()

                self._remove_channel_from_queue()
                self._io_loop.call_later(1, self._create_channel)
            else:
                self.logger.info(f"Reply code: {reply_code}, reply text: {reply_txt}")

        def on_channel_return(channel, method, property, body):
            """If publishing a message has failed, this method will be invoked."""
            self.logger.error(f"Rejected from server. reply code: {method.reply_code}, reply text: {method.reply_text}")
            raise Exception("Failed to publish message.")

        def open_callback(channel):
            self.logger.info("Created channel")
            channel.add_on_close_callback(on_channel_closed)
            channel.add_on_return_callback(on_channel_return)
            channel.add_on_flow_callback(on_channel_flow)
            channel.add_on_cancel_callback(on_channel_cancel)
            self._channel = channel
            if self._exchange is not None:
                self._exchange_declare()
            else:
                self._queue_declare()

        connection.channel(on_open_callback=open_callback)

    def _exchange_declare(self):
        self.logger.info(f"Declaring exchange: {self._exchange}")

        self._channel.exchange_declare(
            callback=self._on_exchange_declared,
            exchange=self._exchange,
            exchange_type=self._exchange_type,
            durable=self._durable,
            auto_delete=self._auto_delete)

    def _on_exchange_declared(self, unframe):
        self.logger.info(f"Declared exchange: {self._exchange}")
        self._queue_declare()

    def _queue_declare(self):
        self.logger.info(f"Declaring queue: {self._queue}")

        self._channel.queue_declare(
            callback=self._on_queue_declared, queue=self._queue, durable=self._durable, auto_delete=self._auto_delete)

    def _on_queue_declared(self, method_frame):
        self.logger.info(f"Declared queue: {method_frame.method.queue}")
        self._queue = method_frame.method.queue
        if self._exchange is not None:
            self._queue_bind()
        else:
            self._on_setup_complete()

    def _queue_bind(self):
        self.logger.info(f"Binding queue: {self._queue} to exchange: {self._exchange}")
        self._channel.queue_bind(
            callback=self._on_queue_bind_ok, queue=self._queue, exchange=self._exchange, routing_key=self._routing_key)

    def _on_queue_bind_ok(self, unframe):
        self.logger.info(f"bound queue: {self._queue} to exchange: {self._exchange}")
        self._on_setup_complete()

    def _on_setup_complete(self):
        self._channel.basic_qos(prefetch_count=self._prefetch_count)
        self._channel_queue.put(self._channel)
        if self._should_consume:
            self._io_loop.call_later(0.01, self.consume, *self._consume_params)
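
A sketch of driving this configuration from a coroutine; connection (an AsyncConnection) and handle_message are hypothetical placeholders:

from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def main(connection, handle_message):
    config = ChannelConfiguration(
        connection, IOLoop.current(),
        exchange="events", routing_key="logs", queue="log-consumer",
        durable=True)
    # registers the consumer once the channel, exchange and queue are declared
    yield config.consume(on_message_callback=handle_message)
    yield config.publish(b"hello")
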
Code Example #7
File: async_task_manager.py  Project: rydzykje/aucote
class AsyncTaskManager(object):
    """
    Aucote uses asynchronous tasks executed in the ioloop. Some of them,
    especially scanners, should finish before the ioloop stops.

    This class should be accessed via the `instance` class method, which returns a global instance of the task manager.

    """
    _instances = {}

    TASKS_POLITIC_WAIT = 0
    TASKS_POLITIC_KILL_WORKING_FIRST = 1
    TASKS_POLITIC_KILL_PROPORTIONS = 2
    TASKS_POLITIC_KILL_WORKING = 3

    def __init__(self, parallel_tasks=10):
        self._shutdown_condition = Event()
        self._stop_condition = Event()
        self._cron_tasks = {}
        self._parallel_tasks = parallel_tasks
        self._tasks = Queue()
        self._task_workers = {}
        self._events = {}
        self._limit = self._parallel_tasks
        self._next_task_number = 0
        self._toucan_keys = {}

    @classmethod
    def instance(cls, name=None, **kwargs):
        """
        Return instance of AsyncTaskManager

        Returns:
            AsyncTaskManager

        """
        if cls._instances.get(name) is None:
            cls._instances[name] = AsyncTaskManager(**kwargs)
        return cls._instances[name]

    @property
    def shutdown_condition(self):
        """
        Event which is set once every job is done and AsyncTaskManager is ready to shut down

        Returns:
            Event
        """
        return self._shutdown_condition

    def start(self):
        """
        Start CronTabCallback tasks

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.start()

        for number in range(self._parallel_tasks):
            self._task_workers[number] = IOLoop.current().add_callback(
                partial(self.process_tasks, number))

        self._next_task_number = self._parallel_tasks

    def add_crontab_task(self, task, cron, event=None):
        """
        Add function to scheduler and execute at cron time

        Args:
            task (function):
            cron (str): crontab value
            event (Event): an event which prevents running tasks with a similar aim, e.g. security scans

        Returns:
            None

        """
        if event is not None:
            event = self._events.setdefault(event, Event())
        self._cron_tasks[task] = AsyncCrontabTask(cron, task, event)

    @gen.coroutine
    def stop(self):
        """
        Stop CronTabCallback tasks and wait for them to finish

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.stop()
        IOLoop.current().add_callback(self._prepare_shutdown)
        yield [self._stop_condition.wait(), self._tasks.join()]
        self._shutdown_condition.set()

    def _prepare_shutdown(self):
        """
        Check if ioloop can be stopped

        Returns:
            None

        """
        if any(task.is_running() for task in self._cron_tasks.values()):
            IOLoop.current().add_callback(self._prepare_shutdown)
            return

        self._stop_condition.set()

    def clear(self):
        """
        Clear list of tasks

        Returns:
            None

        """
        self._cron_tasks = {}
        self._shutdown_condition.clear()
        self._stop_condition.clear()

    async def process_tasks(self, number):
        """
        Consume the task queue. Every task is executed in a separate thread (_Executor).

        """
        log.info("Starting worker %s", number)
        while True:
            try:
                item = self._tasks.get_nowait()
                try:
                    log.debug("Worker %s: starting %s", number, item)
                    thread = _Executor(task=item, number=number)
                    self._task_workers[number] = thread
                    thread.start()

                    while thread.is_alive():
                        await sleep(0.5)
                except:
                    log.exception("Worker %s: exception occurred", number)
                finally:
                    log.debug("Worker %s: %s finished", number, item)
                    self._tasks.task_done()
                    tasks_per_scan = (
                        '{}: {}'.format(scanner, len(tasks))
                        for scanner, tasks in self.tasks_by_scan.items())
                    log.debug("Tasks left in queue: %s (%s)",
                              self.unfinished_tasks, ', '.join(tasks_per_scan))
                    self._task_workers[number] = None
            except QueueEmpty:
                await gen.sleep(0.5)
                if self._stop_condition.is_set() and self._tasks.empty():
                    return
            finally:
                if self._limit < len(self._task_workers):
                    break

        del self._task_workers[number]

        log.info("Closing worker %s", number)

    def add_task(self, task):
        """
        Add task to the queue

        Args:
            task:

        Returns:
            None

        """
        self._tasks.put(task)

    @property
    def unfinished_tasks(self):
        """
        Tasks which are still being processed or are waiting in the queue

        Returns:
            int

        """
        return self._tasks._unfinished_tasks

    @property
    def tasks_by_scan(self):
        """
        Returns queued tasks grouped by scan
        """
        tasks = self._tasks._queue

        return_value = {}

        for task in tasks:
            return_value.setdefault(task.context.scanner.NAME, []).append(task)

        return return_value

    @property
    def cron_tasks(self):
        """
        List of cron tasks

        Returns:
            list

        """
        return self._cron_tasks.values()

    def cron_task(self, name):
        for task in self._cron_tasks.values():
            if task.func.NAME == name:
                return task

    def change_throttling_toucan(self, key, value):
        self.change_throttling(value)

    def change_throttling(self, new_value):
        """
        Change the throttling value, keeping it between 0 and 1.

        The behaviour of the algorithm is described in docs/throttling.md.

        Only working tasks are stopped here. Idle workers stop by themselves.

        """
        if new_value > 1:
            new_value = 1
        if new_value < 0:
            new_value = 0

        new_value = round(new_value * 100) / 100

        old_limit = self._limit
        self._limit = round(self._parallel_tasks * float(new_value))

        working_tasks = [
            number for number, task in self._task_workers.items()
            if task is not None
        ]
        current_tasks = len(self._task_workers)

        task_politic = cfg['service.scans.task_politic']

        if task_politic == self.TASKS_POLITIC_KILL_WORKING_FIRST:
            tasks_to_kill = current_tasks - self._limit
        elif task_politic == self.TASKS_POLITIC_KILL_PROPORTIONS:
            tasks_to_kill = round((old_limit - self._limit) *
                                  len(working_tasks) / self._parallel_tasks)
        elif task_politic == self.TASKS_POLITIC_KILL_WORKING:
            tasks_to_kill = (old_limit - self._limit) - (
                len(self._task_workers) - len(working_tasks))
        else:
            tasks_to_kill = 0

        log.debug('%s tasks will be killed', tasks_to_kill)

        for number in working_tasks:
            if tasks_to_kill <= 0:
                break
            self._task_workers[number].stop()
            tasks_to_kill -= 1

        self._limit = round(self._parallel_tasks * float(new_value))

        current_tasks = len(self._task_workers)

        for number in range(self._limit - current_tasks):
            self._task_workers[self._next_task_number] = None
            IOLoop.current().add_callback(
                partial(self.process_tasks, self._next_task_number))
            self._next_task_number += 1
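
A sketch of the intended entry points, using the global accessor; scan_task is a hypothetical coroutine carrying the scanner context the workers expect:

from tornado.ioloop import IOLoop

manager = AsyncTaskManager.instance(parallel_tasks=10)
manager.add_crontab_task(scan_task, "*/5 * * * *")  # run every 5 minutes
manager.start()                                     # spawn the worker callbacks
IOLoop.current().start()
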
Code Example #8
class AsynSpider(MySpider):
    def __init__(self, out, **kwargs):
        super(AsynSpider, self).__init__(out, **kwargs)
        self.client = httpclient.AsyncHTTPClient()
        self.q = Queue()
        self.fetching, self.fetched = set(), set()

    def assign_jobs(self, jobs):
        for job in jobs:
            self.q.put(job)

    @gen.coroutine
    def run(self):
        if self.q.empty():
            url = LIST_URL + urllib.urlencode(self.list_query)
            self.q.put(url)
        for _ in range(CONCURRENCY):
            self.worker()
        yield self.q.join()
        assert self.fetching == self.fetched
        # print len(self.fetched)
        if isinstance(self._out, Analysis):
            self._out.finish()

    @gen.coroutine
    def worker(self):
        while True:
            yield self.fetch_url()

    @gen.coroutine
    def fetch_url(self):
        current_url = yield self.q.get()
        try:
            if current_url in self.fetching:
                return
            self.fetching.add(current_url)
            request = httpclient.HTTPRequest(current_url, headers=HEADERS)
            resp = yield self.client.fetch(request)
            self.fetched.add(current_url)
            xml = etree.fromstring(resp.body)
            has_total_count = xml.xpath("//totalcount/text()")
            if has_total_count:  # non-empty means this is a list page; otherwise it is a detail page
                total_count = int(has_total_count[0])
                if total_count == 0:
                    return  # list page out of range
                if self.list_query["pageno"] == 1:
                    pageno = 2
                    while pageno < 10:
                        # while pageno <= total_count / PAGE_SIZE:
                        self.list_query["pageno"] = pageno
                        next_list_url = LIST_URL + urllib.urlencode(
                            self.list_query)
                        self.q.put(next_list_url)
                        # logging.info(next_list_url)
                        pageno += 1
                job_ids = xml.xpath("//jobid/text()")
                job_detail_urls = []
                for ID in job_ids:
                    new_detail_query = DETAIL_QUERY.copy()
                    new_detail_query["jobid"] = ID
                    job_detail_urls.append(DETAIL_URL +
                                           urllib.urlencode(new_detail_query))
                for detail_url in job_detail_urls:
                    self.q.put(detail_url)
                    # logging.info(detail_url)

            else:
                self._out.collect(xml)
        finally:
            self.q.task_done()
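
The structure above is Tornado's standard queue-based crawler recipe: fire-and-forget worker coroutines loop over the queue, task_done() runs in a finally block so q.join() cannot hang on errors, and join() resolves once every queued URL is accounted for. A stripped-down skeleton of just that pattern (URL and processing are illustrative):

from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue

q = Queue()
CONCURRENCY = 10

@gen.coroutine
def worker():
    while True:
        url = yield q.get()
        try:
            pass  # fetch and process url here
        finally:
            q.task_done()  # must run even on errors, or join() hangs

@gen.coroutine
def main():
    q.put("http://example.com")
    for _ in range(CONCURRENCY):
        worker()    # fire-and-forget worker coroutines
    yield q.join()  # resolves once task_done() has matched every put()

IOLoop.current().run_sync(main)
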
Code Example #9
class Rx(PrettyPrintable):
    def __init__(self,
                 rx_tree,
                 session_id,
                 header_table=None,
                 io_loop=None,
                 service_name=None,
                 raw_headers=None,
                 trace_id=None):
        if header_table is None:
            header_table = CocaineHeaders()

        if io_loop:
            warnings.warn('io_loop argument is deprecated.',
                          DeprecationWarning)
        # If this is not the main thread
        # and no current IOLoop exists here,
        # IOLoop.instance() becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        self._done = False
        self.session_id = session_id
        self.service_name = service_name
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)
        self._headers = header_table
        self._current_headers = self._headers.merge(raw_headers)
        self.log = get_trace_adapter(log, trace_id)

    @coroutine
    def get(self, timeout=0, protocol=None):
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        if timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload, raw_headers = item
        self._current_headers = self._headers.merge(raw_headers)
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.service_name, res.reason, res.code,
                               res.category)
        else:
            raise Return(res)

    def done(self):
        self._done = True

    def push(self, msg_type, payload, raw_headers):
        dispatch = self.rx_tree.get(msg_type)
        self.log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.service_name,
                                     CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        self.log.info("got message from `%s`: channel id: %s, type: %s",
                      self.service_name, self.session_id, name)
        self._queue.put_nowait((name, payload, raw_headers))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (self.service_name,
                                                  self._queue, self._done)

    @property
    def headers(self):
        return self._current_headers
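
A sketch of draining such a channel until the service closes it, assuming ChokeEvent is importable from the same client package:

from tornado import gen

@gen.coroutine
def drain(rx):
    try:
        while True:
            result = yield rx.get(timeout=5)  # per-item 5-second deadline
            print(result)
    except ChokeEvent:
        pass  # rx.done() was reached and the queue is empty
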
Code Example #10
class Rx(PrettyPrintable):
    def __init__(self, rx_tree, session_id, header_table=None, io_loop=None, service_name=None,
                 raw_headers=None, trace_id=None):
        if header_table is None:
            header_table = CocaineHeaders()

        # If this is not the main thread
        # and no current IOLoop exists here,
        # IOLoop.instance() becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        self._done = False
        self.session_id = session_id
        self.service_name = service_name
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)
        self._headers = header_table
        self._current_headers = self._headers.merge(raw_headers)
        self.log = get_trace_adapter(log, trace_id)

    @coroutine
    def get(self, timeout=0, protocol=None):
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        if timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload, raw_headers = item
        self._current_headers = self._headers.merge(raw_headers)
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.service_name, res.reason, res.code, res.category)
        else:
            raise Return(res)

    def done(self):
        self._done = True

    def push(self, msg_type, payload, raw_headers):
        dispatch = self.rx_tree.get(msg_type)
        self.log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.service_name, CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        self.log.info(
            "got message from `%s`: channel id: %s, type: %s",
            self.service_name,
            self.session_id,
            name
        )
        self._queue.put_nowait((name, payload, raw_headers))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (self.service_name, self._queue, self._done)

    @property
    def headers(self):
        return self._current_headers
Code Example #11
class AsyncConnection(object):
    def __init__(self, *args, **kwargs):
        kwargs["async"] = True

        if "thread_pool" in kwargs:
            self.__thread_pool = kwargs.pop("thread_pool")
        else:
            self.__thread_pool = futures.ThreadPoolExecutor(cpu_count())

        self.__connection = connect(*args, **kwargs)

        self.__io_loop = IOLoop.current()
        self.__connected = False

        log.debug("Trying to connect to postgresql")
        f = self.__wait()
        self.__io_loop.add_future(f, self.__on_connect)
        self.__queue = Queue()
        self.__has_active_cursor = False

        for method in ("get_backend_pid", "get_parameter_status"):
            setattr(self, method, self.__futurize(method))

    def __on_connect(self, result):
        log.debug("Connection establishment")
        self.__connected = True
        self.__io_loop.add_callback(self._loop)

    @coroutine
    def _loop(self):
        log.debug("Starting queue loop")
        while self.__connected:
            while self.__has_active_cursor or self.__connection.isexecuting():
                yield sleep(0.001)

            func, future = yield self.__queue.get()
            result = func()
            if isinstance(result, Future):
                result = yield result

            self.__io_loop.add_callback(future.set_result, result)
            yield self.__wait()

    @coroutine
    def __wait(self):
        log.debug("Waiting for events")
        while not (yield sleep(0.001)):
            try:
                state = self.__connection.poll()
            except QueryCanceledError:
                yield sleep(0.1)
                continue

            f = Future()

            def resolve(fileno, io_op):
                if f.running():
                    f.set_result(True)
                self.__io_loop.remove_handler(fileno)

            if state == psycopg2.extensions.POLL_OK:
                raise Return(True)

            elif state == psycopg2.extensions.POLL_READ:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.READ)
                yield f

            elif state == psycopg2.extensions.POLL_WRITE:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.WRITE)
                yield f

    def __on_cursor_open(self, cursor):
        self.__has_active_cursor = True
        log.debug("Opening cursor")

    def __on_cursor_close(self, cursor):
        self.__has_active_cursor = False
        log.debug("Closing active cursor")

    def cursor(self, **kwargs):
        f = Future()
        self.__io_loop.add_callback(
            self.__queue.put,
            (
                functools.partial(
                    AsyncCursor,
                    self.__connection,
                    self.__thread_pool,
                    self.__wait,
                    on_open=self.__on_cursor_open,
                    on_close=self.__on_cursor_close,
                    **kwargs
                ),
                f,
            ),
        )
        return f

    def cancel(self):
        return self.__thread_pool.submit(self.__connection.cancel)

    def close(self):
        self.__has_active_cursor = True

        @coroutine
        def closer():
            # Fail any queued jobs, then close the underlying connection.
            while not self.__queue.empty():
                func, future = yield self.__queue.get()
                future.set_exception(psycopg2.Error("Connection closed"))

            self.__io_loop.add_callback(self.__connection.close)

        self.__io_loop.add_callback(closer)

    def __futurize(self, item):
        attr = getattr(self.__connection, item)

        @functools.wraps(attr)
        def wrap(*args, **kwargs):
            f = Future()
            self.__io_loop.add_callback(self.__queue.put, (functools.partial(attr, *args, **kwargs), f))
            return f

        return wrap
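
A usage sketch; the DSN is a placeholder, and it assumes AsyncCursor (not shown here) futurizes the blocking psycopg2 cursor API:

from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def main():
    conn = AsyncConnection(dsn="dbname=test")  # placeholder DSN
    cursor = yield conn.cursor()
    yield cursor.execute("SELECT 1")  # assumption: AsyncCursor.execute returns a future

IOLoop.current().run_sync(main)
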
Code Example #12
class Rx(PrettyPrintable):
    def __init__(self, rx_tree, io_loop=None, servicename=None):
        # If this is not the main thread
        # and no current IOLoop exists here,
        # IOLoop.instance() becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        self._done = False
        self.servicename = servicename
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)

    @coroutine
    def get(self, timeout=0, protocol=None):
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        if timeout is None or timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload = item
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.servicename, res.reason,
                               res.code, res.category)
        else:
            raise Return(res)

    def done(self):
        self._done = True

    def push(self, msg_type, payload):
        dispatch = self.rx_tree.get(msg_type)
        log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.servicename, CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        log.debug("name `%s` rx `%s`", name, rx)
        self._queue.put_nowait((name, payload))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (
            self.servicename, self._queue, self._done)
Code Example #13
class AsyncConnection(object):
    def __init__(self, *args, **kwargs):
        kwargs['async'] = True

        if "thread_pool" in kwargs:
            self.__thread_pool = kwargs.pop('thread_pool')
        else:
            self.__thread_pool = futures.ThreadPoolExecutor(cpu_count())

        self.__connection = connect(*args, **kwargs)

        self.__io_loop = IOLoop.current()
        self.__connected = False

        log.debug("Trying to connect to postgresql")
        f = self.__wait()
        self.__io_loop.add_future(f, self.__on_connect)
        self.__queue = Queue()
        self.__has_active_cursor = False

        for method in ('get_backend_pid', 'get_parameter_status'):
            setattr(self, method, self.__futurize(method))

    def __on_connect(self, result):
        log.debug("Connection establishment")
        self.__connected = True
        self.__io_loop.add_callback(self._loop)

    @coroutine
    def _loop(self):
        log.debug("Starting queue loop")
        while self.__connected:
            while self.__has_active_cursor or self.__connection.isexecuting():
                yield sleep(0.001)

            func, future = yield self.__queue.get()
            result = func()
            if isinstance(result, Future):
                result = yield result

            self.__io_loop.add_callback(future.set_result, result)
            yield self.__wait()

    @coroutine
    def __wait(self):
        log.debug("Waiting for events")
        while not (yield sleep(0.001)):
            try:
                state = self.__connection.poll()
            except QueryCanceledError:
                yield sleep(0.1)
                continue

            f = Future()

            def resolve(fileno, io_op):
                if f.running():
                    f.set_result(True)
                self.__io_loop.remove_handler(fileno)

            if state == psycopg2.extensions.POLL_OK:
                raise Return(True)

            elif state == psycopg2.extensions.POLL_READ:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve,
                                           IOLoop.READ)
                yield f

            elif state == psycopg2.extensions.POLL_WRITE:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve,
                                           IOLoop.WRITE)
                yield f

    def __on_cursor_open(self, cursor):
        self.__has_active_cursor = True
        log.debug('Opening cursor')

    def __on_cursor_close(self, cursor):
        self.__has_active_cursor = False
        log.debug('Closing active cursor')

    def cursor(self, **kwargs):
        f = Future()
        self.__io_loop.add_callback(
            self.__queue.put,
            (functools.partial(AsyncCursor,
                               self.__connection,
                               self.__thread_pool,
                               self.__wait,
                               on_open=self.__on_cursor_open,
                               on_close=self.__on_cursor_close,
                               **kwargs), f))
        return f

    def cancel(self):
        return self.__thread_pool.submit(self.__connection.cancel)

    def close(self):
        self.__has_active_cursor = True

        @coroutine
        def closer():
            # Fail any queued jobs, then close the underlying connection.
            while not self.__queue.empty():
                func, future = yield self.__queue.get()
                future.set_exception(psycopg2.Error("Connection closed"))

            self.__io_loop.add_callback(self.__connection.close)

        self.__io_loop.add_callback(closer)

    def __futurize(self, item):
        attr = getattr(self.__connection, item)

        @functools.wraps(attr)
        def wrap(*args, **kwargs):
            f = Future()
            self.__io_loop.add_callback(
                self.__queue.put,
                (functools.partial(attr, *args, **kwargs), f))
            return f

        return wrap