class EventSource(RequestHandler):
    """Server-sent-events handler relaying every item of a stream to one HTTP client."""

    def initialize(self, stream):
        """Subscribe to *stream* and buffer its items for this connection.

        Every item the stream emits is put into ``self.messages``; ``self.store``
        keeps the sink so it can be destroyed when the client goes away.
        """
        # assert isinstance(stream, Stream)
        self.stream = stream
        self.messages = Queue()
        # Set by publish() once the client has disconnected.
        self.finished = False
        self.set_header('content-type', 'text/event-stream')
        self.set_header('cache-control', 'no-cache')
        self.store = self.stream.sink(self.messages.put)

    @gen.coroutine
    def publish(self, message):
        """Pushes data to a listener."""
        try:
            self.write(message >> to_str)
            yield self.flush()
        except StreamClosedError:
            # Client went away: make get() leave its relay loop.
            self.finished = True
            (self.request.remote_ip, StreamClosedError) >> log

    @gen.coroutine
    def get(self, *args, **kwargs):
        """Relay buffered messages to the client until the connection is closed."""
        try:
            while not self.finished:
                message = yield self.messages.get()
                yield self.publish(message)
        except Exception as exc:
            # Best effort: report unexpected failures instead of silently dropping them.
            (self.request.remote_ip, exc) >> log
        finally:
            # Stop receiving new items before discarding the ones already buffered.
            self.store.destroy()
            # Queue.empty() only *tests* emptiness; drain explicitly to release buffered items.
            while not self.messages.empty():
                self.messages.get_nowait()
            self.finish()
class gather(Stream):
    """Stream node that resolves futures from upstream on ``client`` and emits their results.

    Incoming futures are buffered in a bounded queue (``limit``); a background
    coroutine gathers whatever is queued as one batch and emits results in order.
    """

    def __init__(self, child, limit=10, client=None):
        self.client = client or default_client()
        self.queue = Queue(maxsize=limit)
        self.condition = Condition()
        # Futures already taken off the queue whose results have not been emitted yet.
        self._in_flight = 0
        Stream.__init__(self, child)
        self.client.loop.add_callback(self.cb)

    def update(self, x, who=None):
        """Queue *x* for gathering; the returned future applies back-pressure when full."""
        return self.queue.put(x)

    @gen.coroutine
    def cb(self):
        """Background loop: gather queued futures in batches and emit each result."""
        while True:
            first = yield self.queue.get()
            batch = [first]
            while not self.queue.empty():
                batch.append(self.queue.get_nowait())
            self._in_flight = len(batch)
            results = yield self.client._gather(batch)
            for result in results:
                yield self.emit(result)
            self._in_flight = 0
            if self.queue.empty():
                self.condition.notify_all()

    @gen.coroutine
    def flush(self):
        """Wait until every queued future has been gathered and its result emitted.

        An empty queue alone is not enough: the last batch may still be in flight.
        """
        while not self.queue.empty() or self._in_flight:
            yield self.condition.wait()
class ConnectionPool(object):
    """Coroutine-friendly pool of ``Connection`` objects to ``servers``.

    Idle connections wait in a tornado ``Queue`` bounded by ``maxsize``;
    checked-out connections are tracked in ``_in_use``.
    """

    def __init__(self, servers, maxsize=15, minsize=1, loop=None, debug=0):
        loop = loop if loop is not None else tornado.ioloop.IOLoop.instance()
        if debug:
            # No surrounding quote characters: the message itself is the record.
            logging.basicConfig(
                level=logging.DEBUG,
                format="%(levelname)s %(asctime)s"
                       " %(module)s:%(lineno)d %(process)d %(thread)d %(message)s")
        self._loop = loop
        self._servers = servers
        self._minsize = minsize
        self._debug = debug
        self._in_use = set()
        # Idle connections; maxsize caps how many are retained on release().
        self._pool = Queue(maxsize)

    @gen.coroutine
    def clear(self):
        """Clear pool connections."""
        while not self._pool.empty():
            conn = yield self._pool.get()
            conn.close_socket()

    def size(self):
        """Return the number of connections owned by the pool (idle plus in use)."""
        return len(self._in_use) + self._pool.qsize()

    @gen.coroutine
    def acquire(self):
        """Acquire an idle connection from the pool, or create a new one.

        The pool is first topped up to ``minsize`` connections. A new connection
        is created whenever no idle one is available; ``maxsize`` does not limit
        creation, it only bounds how many connections ``release`` keeps idle.

        :return: ``Connection`` (reader, writer)
        """
        while self.size() < self._minsize:
            _conn = yield self._create_new_conn()
            yield self._pool.put(_conn)

        conn = None
        # Keep trying until a truthy connection object is obtained.
        while not conn:
            if not self._pool.empty():
                conn = yield self._pool.get()
            if conn is None:
                conn = yield self._create_new_conn()

        self._in_use.add(conn)
        raise gen.Return(conn)

    @gen.coroutine
    def _create_new_conn(self):
        """Open a new connection to the configured servers."""
        conn = yield Connection.get_conn(self._servers, self._debug)
        raise gen.Return(conn)

    def release(self, conn):
        """Return *conn* to the pool, closing it when the idle pool is already full."""
        self._in_use.remove(conn)
        try:
            self._pool.put_nowait(conn)
        except (QueueEmpty, QueueFull):
            conn.close_socket()
class BatchedStream(object):
    """ Mostly obsolete, see BatchedSend

    Sends messages over ``stream`` in batches at most once per ``interval``
    milliseconds and splits received batches back into single messages.
    """

    def __init__(self, stream, interval):
        self.stream = stream
        self.interval = interval / 1000.0
        self.last_transmission = default_timer()
        self.send_q = Queue()
        self.recv_q = Queue()
        # Initialise before starting the coroutines: they may set it to True synchronously
        # if the stream is already closed, and that must not be overwritten afterwards.
        self._broken = None
        self._background_send_coroutine = self._background_send()
        self._background_recv_coroutine = self._background_recv()
        self.pc = PeriodicCallback(lambda: None, 100)
        self.pc.start()

    def _discard_pending(self):
        """Mark every message still waiting in ``send_q`` as done so ``flush`` can return."""
        while not self.send_q.empty():
            self.send_q.get_nowait()
            self.send_q.task_done()

    @gen.coroutine
    def _background_send(self):
        """Batch queued messages and write them, waiting out the remaining interval."""
        with log_errors():
            while True:
                msg = yield self.send_q.get()
                if msg == "close":
                    # The sentinel counts as a queued task; account for it and for
                    # anything behind it, otherwise send_q.join() never returns.
                    self.send_q.task_done()
                    self._discard_pending()
                    break
                msgs = [msg]
                now = default_timer()
                wait_time = self.last_transmission + self.interval - now
                if wait_time > 0:
                    yield gen.sleep(wait_time)
                while not self.send_q.empty():
                    msgs.append(self.send_q.get_nowait())

                try:
                    yield write(self.stream, msgs)
                except StreamClosedError:
                    self.recv_q.put_nowait("close")
                    self._broken = True
                    for _ in msgs:
                        self.send_q.task_done()
                    self._discard_pending()
                    break

                # Without this the interval is measured from construction time only.
                self.last_transmission = default_timer()

                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for _ in msgs:
                    self.send_q.task_done()

    @gen.coroutine
    def _background_recv(self):
        """Read batches from the stream and enqueue their messages one by one."""
        with log_errors():
            while True:
                try:
                    msgs = yield read(self.stream)
                except StreamClosedError:
                    self.recv_q.put_nowait("close")
                    self.send_q.put_nowait("close")
                    self._broken = True
                    break
                assert isinstance(msgs, list)
                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for msg in msgs:
                    self.recv_q.put_nowait(msg)

    @gen.coroutine
    def flush(self):
        """Wait until every message passed to ``send`` has been handled."""
        yield self.send_q.join()

    @gen.coroutine
    def send(self, msg):
        """Queue *msg* for the next batch; raises StreamClosedError once the stream broke."""
        if self._broken:
            raise StreamClosedError("Batch Stream is Closed")
        else:
            self.send_q.put_nowait(msg)

    @gen.coroutine
    def recv(self):
        """Return the next received message; raises StreamClosedError when the stream closed."""
        result = yield self.recv_q.get()
        if result == "close":
            raise StreamClosedError("Batched Stream is Closed")
        else:
            raise gen.Return(result)

    @gen.coroutine
    def close(self):
        """Flush pending messages, then close the underlying stream."""
        yield self.flush()
        raise gen.Return(self.stream.close())

    def closed(self):
        """Return whether the underlying stream is closed."""
        return self.stream.closed()
class BatchedStream(object):
    """ Mostly obsolete, see BatchedSend

    Sends messages over ``stream`` in batches at most once per ``interval``
    milliseconds and splits received batches back into single messages.
    """

    def __init__(self, stream, interval):
        self.stream = stream
        self.interval = interval / 1000.
        self.last_transmission = default_timer()
        self.send_q = Queue()
        self.recv_q = Queue()
        # Initialise before starting the coroutines: they may set it to True synchronously
        # if the stream is already closed, and that must not be overwritten afterwards.
        self._broken = None
        self._background_send_coroutine = self._background_send()
        self._background_recv_coroutine = self._background_recv()
        self.pc = PeriodicCallback(lambda: None, 100)
        self.pc.start()

    def _discard_pending(self):
        """Mark every message still waiting in ``send_q`` as done so ``flush`` can return."""
        while not self.send_q.empty():
            self.send_q.get_nowait()
            self.send_q.task_done()

    @gen.coroutine
    def _background_send(self):
        """Batch queued messages and write them, waiting out the remaining interval."""
        with log_errors():
            while True:
                msg = yield self.send_q.get()
                if msg == 'close':
                    # The sentinel counts as a queued task; account for it and for
                    # anything behind it, otherwise send_q.join() never returns.
                    self.send_q.task_done()
                    self._discard_pending()
                    break
                msgs = [msg]
                now = default_timer()
                wait_time = self.last_transmission + self.interval - now
                if wait_time > 0:
                    yield gen.sleep(wait_time)
                while not self.send_q.empty():
                    msgs.append(self.send_q.get_nowait())

                try:
                    yield write(self.stream, msgs)
                except StreamClosedError:
                    self.recv_q.put_nowait('close')
                    self._broken = True
                    for _ in msgs:
                        self.send_q.task_done()
                    self._discard_pending()
                    break

                # Without this the interval is measured from construction time only.
                self.last_transmission = default_timer()

                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for _ in msgs:
                    self.send_q.task_done()

    @gen.coroutine
    def _background_recv(self):
        """Read batches from the stream and enqueue their messages one by one."""
        with log_errors():
            while True:
                try:
                    msgs = yield read(self.stream)
                except StreamClosedError:
                    self.recv_q.put_nowait('close')
                    self.send_q.put_nowait('close')
                    self._broken = True
                    break
                assert isinstance(msgs, list)
                if len(msgs) > 1:
                    logger.debug("Batched messages: %d", len(msgs))
                for msg in msgs:
                    self.recv_q.put_nowait(msg)

    @gen.coroutine
    def flush(self):
        """Wait until every message passed to ``send`` has been handled."""
        yield self.send_q.join()

    @gen.coroutine
    def send(self, msg):
        """Queue *msg* for the next batch; raises StreamClosedError once the stream broke."""
        if self._broken:
            raise StreamClosedError('Batch Stream is Closed')
        else:
            self.send_q.put_nowait(msg)

    @gen.coroutine
    def recv(self):
        """Return the next received message; raises StreamClosedError when the stream closed."""
        result = yield self.recv_q.get()
        if result == 'close':
            raise StreamClosedError('Batched Stream is Closed')
        else:
            raise gen.Return(result)

    @gen.coroutine
    def close(self):
        """Flush pending messages, then close the underlying stream."""
        yield self.flush()
        raise gen.Return(self.stream.close())

    def closed(self):
        """Return whether the underlying stream is closed."""
        return self.stream.closed()
class ChannelConfiguration:
    """Opens a channel on an ``AsyncConnection`` and declares exchange/queue topology on it.

    Once declaration (and binding, when an exchange is set) completes, the
    channel is put into ``_channel_queue`` (capacity 1); ``consume`` and
    ``publish`` obtain it from there.
    """
    _USER_CLOSE_CODE = 0
    _NORMAL_CLOSE_CODE = 200
    _NO_ROUTE_CODE = 312
    logger = logger
    connection: AsyncConnection

    def __init__(self, connection: AsyncConnection, io_loop, exchange=None, exchange_type=None,
                 queue=None, routing_key=None, durable=False, auto_delete=False, prefetch_count=None):
        self._channel = None
        self.connection = connection
        self._io_loop = io_loop
        self._channel_queue = Queue(maxsize=1)
        self._queue = queue if queue is not None else ""
        self._exchange = exchange
        if exchange_type is None:
            exchange_type = 'topic'
        self._exchange_type = exchange_type
        self._routing_key = routing_key
        self._durable = durable
        self._auto_delete = auto_delete
        if prefetch_count is None:
            prefetch_count = 1
        self._prefetch_count = prefetch_count
        self._should_consume = False
        # Positional arguments for consume(); star-unpacked in _on_setup_complete, so a list.
        self._consume_params = []

    @gen.coroutine
    def consume(self, on_message_callback, handler=None, no_ack=False):
        """Start consuming ``self._queue``.

        The arguments are remembered so consuming restarts when the channel is
        set up again. When *handler* is given it is bound to
        *on_message_callback* as the ``handler`` keyword argument.
        """
        self.logger.info(f"[start consuming] routing key: {self._routing_key}; queue name: {self._queue}")
        channel = yield self._get_channel()
        self._should_consume = True
        self._consume_params = [on_message_callback, handler, no_ack]
        if handler is not None:
            channel.basic_consume(
                queue=self._queue,
                auto_ack=no_ack,
                on_message_callback=functools.partial(on_message_callback, handler=handler),
            )
        else:
            channel.basic_consume(queue=self._queue, on_message_callback=on_message_callback, auto_ack=no_ack)

    @gen.coroutine
    def publish(self, body, mandatory=None, properties=None, reply_to=None):
        """Publish *body* to the configured exchange/routing key, or straight to *reply_to*."""
        channel = yield self._get_channel()
        if reply_to is not None:
            # Default exchange routes directly to the queue named reply_to.
            exchange = ""
            routing_key = reply_to
        else:
            exchange = self._exchange
            routing_key = self._routing_key
        self.logger.info(f"Publishing message. exchange: {exchange}; routing_key: {routing_key}")
        channel.basic_publish(exchange=exchange, routing_key=routing_key, body=body,
                              mandatory=mandatory, properties=properties)

    @gen.coroutine
    def _get_channel(self):
        """Return the ready channel, starting channel creation first if none is available."""
        if self._channel_queue.empty():
            yield self._create_channel()
        channel = yield self._top()
        return channel

    @gen.coroutine
    def _top(self):
        """Peek at the ready channel: wait for it, then put it straight back."""
        channel = yield self._channel_queue.get()
        self._channel_queue.put(channel)
        return channel

    def _remove_channel_from_queue(self):
        """Drop the ready channel from ``_channel_queue`` if one is there."""
        try:
            self._channel_queue.get_nowait()
        except QueueEmpty:
            pass

    @gen.coroutine
    def _create_channel(self):
        """Open a new channel; topology declaration continues in ``open_callback``."""
        self.logger.info("creating channel")
        connection = yield self.connection.get_connection()

        def on_channel_flow(*args, **kwargs):
            pass

        def on_channel_cancel(frame):
            self.logger.error("Channel was canceled")
            if not self._channel_queue.empty():
                channel = self._channel
                # Close only a channel that is neither closed nor already closing.
                if channel and not (channel.is_closed or channel.is_closing):
                    channel.close()

        def on_channel_closed(channel, reason):
            reply_code, reply_txt = reason.args
            self.logger.info(f'Channel {channel} was closed: {reason}')
            if reply_code not in [self._NORMAL_CLOSE_CODE, self._USER_CLOSE_CODE]:
                self.logger.error(f"Channel closed. reply code: {reply_code}; reply text: {reply_txt}. "
                                  f"System will exit")
                if connection and not (connection.is_closed or connection.is_closing):
                    connection.close()
                self._remove_channel_from_queue()
                # Re-create the channel (and its topology) after a short delay.
                self._io_loop.call_later(1, self._create_channel)
            else:
                self.logger.info(f"Reply code: {reply_code}, reply text: {reply_txt}")

        def on_channel_return(channel, method, properties, body):
            """If publish message has failed, this method will be invoked."""
            self.logger.error(f"Rejected from server. reply code: {method.reply_code}, "
                              f"reply text: {method.reply_text}")
            raise Exception("Failed to publish message.")

        def open_callback(channel):
            self.logger.info("Created channel")
            channel.add_on_close_callback(on_channel_closed)
            channel.add_on_return_callback(on_channel_return)
            channel.add_on_flow_callback(on_channel_flow)
            channel.add_on_cancel_callback(on_channel_cancel)
            self._channel = channel
            if self._exchange is not None:
                self._exchange_declare()
            else:
                self._queue_declare()

        connection.channel(on_open_callback=open_callback)

    def _exchange_declare(self):
        """Declare the configured exchange; continues with the queue declaration."""
        self.logger.info(f"Declaring exchange: {self._exchange}")
        self._channel.exchange_declare(
            callback=self._on_exchange_declared,
            exchange=self._exchange,
            exchange_type=self._exchange_type,
            durable=self._durable,
            auto_delete=self._auto_delete)

    def _on_exchange_declared(self, unframe):
        self.logger.info(f"Declared exchange: {self._exchange}")
        self._queue_declare()

    def _queue_declare(self):
        """Declare the configured queue (an empty name lets the server pick one)."""
        self.logger.info(f"Declaring queue: {self._queue}")
        self._channel.queue_declare(
            callback=self._on_queue_declared, queue=self._queue, durable=self._durable,
            auto_delete=self._auto_delete)

    def _on_queue_declared(self, method_frame):
        self.logger.info(f"Declared queue: {method_frame.method.queue}")
        # Remember the name the server reports for the declared queue.
        self._queue = method_frame.method.queue
        if self._exchange is not None:
            self._queue_bind()
        else:
            self._on_setup_complete()

    def _queue_bind(self):
        """Bind the declared queue to the exchange with the configured routing key."""
        self.logger.info(f"Binding queue: {self._queue} to exchange: {self._exchange}")
        self._channel.queue_bind(
            callback=self._on_queue_bind_ok,
            queue=self._queue,
            exchange=self._exchange,
            routing_key=self._routing_key)

    def _on_queue_bind_ok(self, unframe):
        self.logger.info(f"bound queue: {self._queue} to exchange: {self._exchange}")
        self._on_setup_complete()

    def _on_setup_complete(self):
        """Apply QoS, publish the channel as ready, and resume consuming if requested."""
        self._channel.basic_qos(prefetch_count=self._prefetch_count)
        self._channel_queue.put(self._channel)
        if self._should_consume:
            self._io_loop.call_later(0.01, self.consume, *self._consume_params)
class AsyncTaskManager(object):
    """
    Aucote uses asynchronous tasks executed in the ioloop. Some of them, especially scanners, should finish
    before the ioloop stops.

    This class should be accessed through the ``instance`` class method, which returns a named global
    instance of the task manager.

    """
    # Registry of named global instances, shared by the whole class.
    _instances = {}

    TASKS_POLITIC_WAIT = 0
    TASKS_POLITIC_KILL_WORKING_FIRST = 1
    TASKS_POLITIC_KILL_PROPORTIONS = 2
    TASKS_POLITIC_KILL_WORKING = 3

    def __init__(self, parallel_tasks=10):
        self._shutdown_condition = Event()
        self._stop_condition = Event()
        self._cron_tasks = {}
        self._parallel_tasks = parallel_tasks
        self._tasks = Queue()
        # worker number -> running _Executor thread, or None while the worker is idle
        self._task_workers = {}
        self._events = {}
        # Current number of workers allowed (changed by throttling).
        self._limit = self._parallel_tasks
        self._next_task_number = 0
        self._toucan_keys = {}

    @classmethod
    def instance(cls, name=None, **kwargs):
        """
        Return the instance of AsyncTaskManager registered under ``name``, creating it on first use

        Returns:
            AsyncTaskManager

        """
        if cls._instances.get(name) is None:
            cls._instances[name] = AsyncTaskManager(**kwargs)
        return cls._instances[name]

    @property
    def shutdown_condition(self):
        """
        Event which is resolved if every job is done and AsyncTaskManager is ready to shutdown

        Returns:
            Event
        """
        return self._shutdown_condition

    def start(self):
        """
        Start CronTabCallback tasks and ``parallel_tasks`` queue workers

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.start()

        for number in range(self._parallel_tasks):
            self._task_workers[number] = IOLoop.current().add_callback(partial(self.process_tasks, number))

        self._next_task_number = self._parallel_tasks

    def add_crontab_task(self, task, cron, event=None):
        """
        Add function to scheduler and execute at cron time

        Args:
            task (function):
            cron (str): crontab value
            event (Event): event which prevents running tasks with a similar aim at the same time,
                eg. security scans

        Returns:
            None

        """
        if event is not None:
            event = self._events.setdefault(event, Event())
        self._cron_tasks[task] = AsyncCrontabTask(cron, task, event)

    @gen.coroutine
    def stop(self):
        """
        Stop CronTabCallback tasks and wait on them to finish

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.stop()
        IOLoop.current().add_callback(self._prepare_shutdown)
        yield [self._stop_condition.wait(), self._tasks.join()]
        self._shutdown_condition.set()

    def _prepare_shutdown(self):
        """
        Check if ioloop can be stopped; reschedules itself while any cron task is still running

        Returns:
            None

        """
        if any(task.is_running() for task in self._cron_tasks.values()):
            IOLoop.current().add_callback(self._prepare_shutdown)
            return

        self._stop_condition.set()

    def clear(self):
        """
        Clear list of tasks

        Returns:
            None

        """
        self._cron_tasks = {}
        self._shutdown_condition.clear()
        self._stop_condition.clear()

    async def process_tasks(self, number):
        """
        Worker loop: take tasks from the queue and execute each one in a separate thread (_Executor)

        NOTE: the ``break`` in the outer ``finally`` runs after every iteration, including one that
        is returning or raising; when the worker count exceeds the limit it overrides that
        return/exception and leaves the loop so the worker deregisters itself below. A worker that
        returns on shutdown (limit not exceeded) keeps its entry in ``_task_workers``.
        """
        log.info("Starting worker %s", number)
        while True:
            try:
                item = self._tasks.get_nowait()
                try:
                    log.debug("Worker %s: starting %s", number, item)
                    thread = _Executor(task=item, number=number)
                    self._task_workers[number] = thread
                    thread.start()

                    while thread.is_alive():
                        await sleep(0.5)
                except Exception:
                    # Narrowed from a bare except so BaseException (e.g. cancellation) is not swallowed.
                    log.exception("Worker %s: exception occurred", number)
                finally:
                    log.debug("Worker %s: %s finished", number, item)
                    self._tasks.task_done()
                    tasks_per_scan = ('{}: {}'.format(scanner, len(tasks))
                                      for scanner, tasks in self.tasks_by_scan.items())
                    log.debug("Tasks left in queue: %s (%s)", self.unfinished_tasks, ', '.join(tasks_per_scan))
                    self._task_workers[number] = None
            except QueueEmpty:
                await gen.sleep(0.5)
                if self._stop_condition.is_set() and self._tasks.empty():
                    return
            finally:
                # Throttled down: this worker is surplus and exits.
                if self._limit < len(self._task_workers):
                    break

        del self._task_workers[number]

        log.info("Closing worker %s", number)

    def add_task(self, task):
        """
        Add task to the queue

        Args:
            task:

        Returns:
            None

        """
        self._tasks.put(task)

    @property
    def unfinished_tasks(self):
        """
        Tasks which are still being processed or waiting in the queue

        Returns:
            int

        """
        return self._tasks._unfinished_tasks

    @property
    def tasks_by_scan(self):
        """
        Returns queued tasks grouped by scanner name
        """
        tasks = self._tasks._queue
        return_value = {}
        for task in tasks:
            return_value.setdefault(task.context.scanner.NAME, []).append(task)
        return return_value

    @property
    def cron_tasks(self):
        """
        List of cron tasks

        Returns:
            list

        """
        return self._cron_tasks.values()

    def cron_task(self, name):
        """Return the cron task whose function has ``NAME == name``, or None."""
        for task in self._cron_tasks.values():
            if task.func.NAME == name:
                return task

    def change_throttling_toucan(self, key, value):
        """Toucan-watch adapter: ignores *key* and applies *value* as the throttling value."""
        self.change_throttling(value)

    def change_throttling(self, new_value):
        """
        Change throttling value. Keeps throttling value between 0 and 1.

        Behaviour of algorithm is described in docs/throttling.md

        Only working tasks are closed here. Idle workers stop by themselves.

        """
        if new_value > 1:
            new_value = 1
        if new_value < 0:
            new_value = 0

        new_value = round(new_value * 100) / 100

        old_limit = self._limit
        self._limit = round(self._parallel_tasks * float(new_value))

        working_tasks = [number for number, task in self._task_workers.items() if task is not None]
        current_tasks = len(self._task_workers)

        task_politic = cfg['service.scans.task_politic']
        if task_politic == self.TASKS_POLITIC_KILL_WORKING_FIRST:
            tasks_to_kill = current_tasks - self._limit
        elif task_politic == self.TASKS_POLITIC_KILL_PROPORTIONS:
            tasks_to_kill = round((old_limit - self._limit) * len(working_tasks) / self._parallel_tasks)
        elif task_politic == self.TASKS_POLITIC_KILL_WORKING:
            tasks_to_kill = (old_limit - self._limit) - (len(self._task_workers) - len(working_tasks))
        else:
            tasks_to_kill = 0

        log.debug('%s tasks will be killed', tasks_to_kill)

        for number in working_tasks:
            if tasks_to_kill <= 0:
                break
            self._task_workers[number].stop()
            tasks_to_kill -= 1

        # Spawn extra workers when the new limit exceeds the registered worker count.
        current_tasks = len(self._task_workers)
        for number in range(self._limit - current_tasks):
            self._task_workers[self._next_task_number] = None
            IOLoop.current().add_callback(partial(self.process_tasks, self._next_task_number))
            self._next_task_number += 1
class AsynSpider(MySpider):
    """Concurrent crawler: ``CONCURRENCY`` workers pull URLs from ``self.q`` and fetch them."""

    def __init__(self, out, **kwargs):
        super(AsynSpider, self).__init__(out, **kwargs)
        self.client = httpclient.AsyncHTTPClient()
        self.q = Queue()
        # URLs claimed by a worker / URLs whose fetch completed.
        self.fetching, self.fetched = set(), set()

    def assign_jobs(self, jobs):
        """Enqueue every URL in *jobs* for crawling."""
        for job in jobs:
            self.q.put(job)

    @gen.coroutine
    def run(self):
        """Crawl until the queue is drained, then finish the output sink.

        Seeds the queue with the first list page when no jobs were assigned.
        """
        if self.q.empty():
            url = LIST_URL + urllib.urlencode(self.list_query)
            self.q.put(url)
        for _ in range(CONCURRENCY):
            self.worker()
        yield self.q.join()
        # Every claimed URL must have been fetched; a failed fetch makes this fail loudly.
        assert self.fetching == self.fetched
        if isinstance(self._out, Analysis):
            self._out.finish()

    @gen.coroutine
    def worker(self):
        """Process queued URLs forever.

        A failing URL is logged instead of ending the worker: workers are started
        fire-and-forget, so an escaping exception would silently shrink the pool.
        """
        import logging

        while True:
            try:
                yield self.fetch_url()
            except Exception:
                # fetch_url already called task_done() in its finally, so join() stays balanced.
                logging.exception("Failed to process a queued URL")

    @gen.coroutine
    def fetch_url(self):
        """Fetch one queued URL: list pages enqueue detail/next-page URLs, detail pages are collected."""
        current_url = yield self.q.get()
        try:
            if current_url in self.fetching:
                return
            self.fetching.add(current_url)
            request = httpclient.HTTPRequest(current_url, headers=HEADERS)
            resp = yield self.client.fetch(request)
            self.fetched.add(current_url)
            xml = etree.fromstring(resp.body)
            has_total_count = xml.xpath("//totalcount/text()")
            if has_total_count:  # non-empty means this is a list page, otherwise a detail page
                total_count = int(has_total_count[0])
                if total_count == 0:
                    return
                # List pagination: only the first list page enqueues the following pages.
                if self.list_query["pageno"] == 1:
                    pageno = 2
                    # NOTE: hard-coded cap (pages 2-9); the full range would be total_count / PAGE_SIZE.
                    while pageno < 10:
                        self.list_query["pageno"] = pageno
                        next_list_url = LIST_URL + urllib.urlencode(self.list_query)
                        self.q.put(next_list_url)
                        pageno += 1
                job_ids = xml.xpath("//jobid/text()")
                job_detail_urls = []
                for ID in job_ids:
                    new_detail_query = DETAIL_QUERY.copy()
                    new_detail_query["jobid"] = ID
                    job_detail_urls.append(DETAIL_URL + urllib.urlencode(new_detail_query))
                for detail_url in job_detail_urls:
                    self.q.put(detail_url)
            else:
                self._out.collect(xml)
        finally:
            self.q.task_done()
class Rx(PrettyPrintable):
    """Receiving side of a service session: buffers incoming messages and follows the rx protocol tree.

    ``push``/``error`` enqueue items; ``get`` dequeues and decodes them.
    """

    def __init__(self, rx_tree, session_id, header_table=None, io_loop=None, service_name=None,
                 raw_headers=None, trace_id=None):
        if header_table is None:
            header_table = CocaineHeaders()

        if io_loop:
            warnings.warn('io_loop argument is deprecated.', DeprecationWarning)
        # If it's not the main thread
        # and a current IOloop doesn't exist here,
        # IOLoop.instance becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        # True once the terminal transition (rx == {}) has been received.
        self._done = False
        self.session_id = session_id
        self.service_name = service_name
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)
        self._headers = header_table
        self._current_headers = self._headers.merge(raw_headers)
        self.log = get_trace_adapter(log, trace_id)

    @coroutine
    def get(self, timeout=0, protocol=None):
        """Wait for the next message and decode it with *protocol*.

        :param timeout: seconds to wait; ``None``, ``0`` or a negative value waits indefinitely.
            If no item arrives in time, the queue's timeout error propagates.
        :param protocol: callable ``protocol(name, payload)``; defaults to the type detected from ``rx_tree``.
        :raises ChokeEvent: the session is done and nothing is left to read.
        :raises ServiceError: the decoded result is a ``ProtocolError``.
        An exception previously passed to ``error`` is re-raised as-is.
        """
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        if timeout is None or timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload, raw_headers = item
        self._current_headers = self._headers.merge(raw_headers)
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.service_name, res.reason, res.code, res.category)
        else:
            raise Return(res)

    def done(self):
        """Mark the session as finished; ``get`` raises ChokeEvent once the queue is drained."""
        self._done = True

    def push(self, msg_type, payload, raw_headers):
        """Enqueue an incoming message and advance the protocol tree.

        :raises InvalidMessageType: *msg_type* is not valid in the current state.
        """
        dispatch = self.rx_tree.get(msg_type)
        self.log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.service_name, CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        self.log.info("got message from `%s`: channel id: %s, type: %s",
                      self.service_name, self.session_id, name)
        self._queue.put_nowait((name, payload, raw_headers))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        """Enqueue *err* so the next ``get`` raises it."""
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (self.service_name, self._queue, self._done)

    @property
    def headers(self):
        """Headers merged from the most recently read message."""
        return self._current_headers
class Rx(PrettyPrintable):
    """Receiving side of a service session: buffers incoming messages and follows the rx protocol tree.

    ``push``/``error`` enqueue items; ``get`` dequeues and decodes them.
    """

    def __init__(self, rx_tree, session_id, header_table=None, io_loop=None, service_name=None,
                 raw_headers=None, trace_id=None):
        if header_table is None:
            header_table = CocaineHeaders()

        # If it's not the main thread
        # and a current IOloop doesn't exist here,
        # IOLoop.instance becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        # True once the terminal transition (rx == {}) has been received.
        self._done = False
        self.session_id = session_id
        self.service_name = service_name
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)
        self._headers = header_table
        self._current_headers = self._headers.merge(raw_headers)
        self.log = get_trace_adapter(log, trace_id)

    @coroutine
    def get(self, timeout=0, protocol=None):
        """Wait for the next message and decode it with *protocol*.

        :param timeout: seconds to wait; ``None``, ``0`` or a negative value waits indefinitely.
            If no item arrives in time, the queue's timeout error propagates.
        :param protocol: callable ``protocol(name, payload)``; defaults to the type detected from ``rx_tree``.
        :raises ChokeEvent: the session is done and nothing is left to read.
        :raises ServiceError: the decoded result is a ``ProtocolError``.
        An exception previously passed to ``error`` is re-raised as-is.
        """
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        if timeout is None or timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload, raw_headers = item
        self._current_headers = self._headers.merge(raw_headers)
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.service_name, res.reason, res.code, res.category)
        else:
            raise Return(res)

    def done(self):
        """Mark the session as finished; ``get`` raises ChokeEvent once the queue is drained."""
        self._done = True

    def push(self, msg_type, payload, raw_headers):
        """Enqueue an incoming message and advance the protocol tree.

        :raises InvalidMessageType: *msg_type* is not valid in the current state.
        """
        dispatch = self.rx_tree.get(msg_type)
        self.log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.service_name, CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        self.log.info(
            "got message from `%s`: channel id: %s, type: %s",
            self.service_name,
            self.session_id,
            name
        )
        self._queue.put_nowait((name, payload, raw_headers))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        """Enqueue *err* so the next ``get`` raises it."""
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (self.service_name, self._queue, self._done)

    @property
    def headers(self):
        """Headers merged from the most recently read message."""
        return self._current_headers
class AsyncConnection(object):
    """Tornado wrapper around an asynchronous psycopg2 connection.

    Operations are queued as ``(callable, future)`` pairs and executed one at a
    time by ``_loop`` once the connection is idle; each caller gets a Future.
    """

    def __init__(self, *args, **kwargs):
        kwargs["async"] = True

        if "thread_pool" in kwargs:
            self.__thread_pool = kwargs.pop("thread_pool")
        else:
            self.__thread_pool = futures.ThreadPoolExecutor(cpu_count())

        self.__connection = connect(*args, **kwargs)

        self.__io_loop = IOLoop.current()
        self.__connected = False
        # Set by close(); prevents a late connect callback from starting the loop.
        self.__closed = False
        # Create shared state before the connect-wait starts, so callbacks never see it missing.
        self.__queue = Queue()
        self.__has_active_cursor = False

        log.debug("Trying to connect to postgresql")
        f = self.__wait()
        self.__io_loop.add_future(f, self.__on_connect)

        for method in ("get_backend_pid", "get_parameter_status"):
            setattr(self, method, self.__futurize(method))

    def __on_connect(self, result):
        log.debug("Connection establishment")
        if self.__closed:
            # close() was called before the handshake finished.
            return
        self.__connected = True
        self.__io_loop.add_callback(self._loop)

    @coroutine
    def _loop(self):
        """Execute queued operations one at a time while the connection is open."""
        log.debug("Starting queue loop")
        while self.__connected:
            # Wait until no cursor is open and no query is executing (or until closed).
            while self.__connected and (self.__has_active_cursor or self.__connection.isexecuting()):
                yield sleep(0.001)
            if not self.__connected:
                break

            func, future = yield self.__queue.get()
            if not self.__connected:
                # Request arrived after close(): fail it instead of touching a closed connection.
                future.set_exception(psycopg2.Error("Connection closed"))
                break

            try:
                result = func()
                if isinstance(result, Future):
                    result = yield result
            except Exception as exc:
                # Deliver the failure to the caller; an escaping exception would kill the loop.
                self.__io_loop.add_callback(future.set_exception, exc)
            else:
                self.__io_loop.add_callback(future.set_result, result)
            yield self.__wait()

    @coroutine
    def __wait(self):
        """Poll the connection until psycopg2 reports POLL_OK, waiting on the socket in between."""
        log.debug("Waiting for events")
        while not (yield sleep(0.001)):
            try:
                state = self.__connection.poll()
            except QueryCanceledError:
                yield sleep(0.1)
                continue

            f = Future()

            def resolve(fileno, io_op):
                if f.running():
                    f.set_result(True)
                self.__io_loop.remove_handler(fileno)

            if state == psycopg2.extensions.POLL_OK:
                raise Return(True)
            elif state == psycopg2.extensions.POLL_READ:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.READ)
                yield f
            elif state == psycopg2.extensions.POLL_WRITE:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.WRITE)
                yield f

    def __on_cursor_open(self, cursor):
        self.__has_active_cursor = True
        log.debug("Opening cursor")

    def __on_cursor_close(self, cursor):
        self.__has_active_cursor = False
        log.debug("Closing active cursor")

    def cursor(self, **kwargs):
        """Queue creation of an ``AsyncCursor``; returns a Future resolving to it."""
        f = Future()
        self.__io_loop.add_callback(
            self.__queue.put,
            (
                functools.partial(
                    AsyncCursor,
                    self.__connection,
                    self.__thread_pool,
                    self.__wait,
                    on_open=self.__on_cursor_open,
                    on_close=self.__on_cursor_close,
                    **kwargs
                ),
                f,
            ),
        )
        return f

    def cancel(self):
        """Cancel the running query from the thread pool; returns the executor future."""
        return self.__thread_pool.submit(self.__connection.cancel)

    def close(self):
        """Stop the operation loop, fail every queued request and close the connection."""
        # Keep _loop from dispatching anything else, then let it exit.
        self.__has_active_cursor = True
        self.__closed = True
        self.__connected = False

        # Fail requests still waiting so their callers are not left hanging.
        while not self.__queue.empty():
            func, future = self.__queue.get_nowait()
            future.set_exception(psycopg2.Error("Connection closed"))

        self.__io_loop.add_callback(self.__connection.close)

    def __futurize(self, item):
        """Wrap connection method *item* so calls are queued and return a Future."""
        attr = getattr(self.__connection, item)

        @functools.wraps(attr)
        def wrap(*args, **kwargs):
            f = Future()
            self.__io_loop.add_callback(self.__queue.put, (functools.partial(attr, *args, **kwargs), f))
            return f

        return wrap
class Rx(PrettyPrintable):
    """Receiving side of a service channel: buffers incoming messages and follows the rx protocol tree."""

    def __init__(self, rx_tree, io_loop=None, servicename=None):
        # If it's not the main thread
        # and a current IOloop doesn't exist here,
        # IOLoop.instance becomes self._io_loop
        self._io_loop = io_loop or IOLoop.current()
        self._queue = Queue()
        # True once the terminal transition (rx == {}) has been received.
        self._done = False
        self.servicename = servicename
        self.rx_tree = rx_tree
        self.default_protocol = detect_protocol_type(rx_tree)

    @coroutine
    def get(self, timeout=0, protocol=None):
        """Wait for the next message and decode it with *protocol*.

        :param timeout: seconds to wait; ``None``, ``0`` or a negative value waits indefinitely.
        :raises ChokeEvent: the channel is done and nothing is left to read.
        :raises ServiceError: the decoded result is a ``ProtocolError``.
        An exception previously passed to ``error`` is re-raised as-is.
        """
        if self._done and self._queue.empty():
            raise ChokeEvent()

        # to pull various service errors
        # Test for None first: ``None <= 0`` raises TypeError on Python 3.
        if timeout is None or timeout <= 0:
            item = yield self._queue.get()
        else:
            deadline = datetime.timedelta(seconds=timeout)
            item = yield self._queue.get(deadline)

        if isinstance(item, Exception):
            raise item

        if protocol is None:
            protocol = self.default_protocol

        name, payload = item
        res = protocol(name, payload)
        if isinstance(res, ProtocolError):
            raise ServiceError(self.servicename, res.reason, res.code, res.category)
        else:
            raise Return(res)

    def done(self):
        """Mark the channel as finished; ``get`` raises ChokeEvent once the queue is drained."""
        self._done = True

    def push(self, msg_type, payload):
        """Enqueue an incoming message and advance the protocol tree.

        :raises InvalidMessageType: *msg_type* is not valid in the current state.
        """
        dispatch = self.rx_tree.get(msg_type)
        log.debug("dispatch %s %.300s", dispatch, payload)
        if dispatch is None:
            raise InvalidMessageType(self.servicename, CocaineErrno.INVALIDMESSAGETYPE,
                                     "unexpected message type %s" % msg_type)
        name, rx = dispatch
        log.debug("name `%s` rx `%s`", name, rx)
        self._queue.put_nowait((name, payload))
        if rx == {}:  # the last transition
            self.done()
        elif rx is not None:  # not a recursive transition
            self.rx_tree = rx

    def error(self, err):
        """Enqueue *err* so the next ``get`` raises it."""
        self._queue.put_nowait(err)

    def closed(self):
        return self._done

    def _format(self):
        return "name: %s, queue: %s, done: %s" % (
            self.servicename, self._queue, self._done)
class AsyncConnection(object):
    """Tornado wrapper around an asynchronous psycopg2 connection.

    Operations are queued as ``(callable, future)`` pairs and executed one at a
    time by ``_loop`` once the connection is idle; each caller gets a Future.
    """

    def __init__(self, *args, **kwargs):
        kwargs['async'] = True

        if "thread_pool" in kwargs:
            self.__thread_pool = kwargs.pop('thread_pool')
        else:
            self.__thread_pool = futures.ThreadPoolExecutor(cpu_count())

        self.__connection = connect(*args, **kwargs)

        self.__io_loop = IOLoop.current()
        self.__connected = False
        # Set by close(); prevents a late connect callback from starting the loop.
        self.__closed = False
        # Create shared state before the connect-wait starts, so callbacks never see it missing.
        self.__queue = Queue()
        self.__has_active_cursor = False

        log.debug("Trying to connect to postgresql")
        f = self.__wait()
        self.__io_loop.add_future(f, self.__on_connect)

        for method in ('get_backend_pid', 'get_parameter_status'):
            setattr(self, method, self.__futurize(method))

    def __on_connect(self, result):
        log.debug("Connection establishment")
        if self.__closed:
            # close() was called before the handshake finished.
            return
        self.__connected = True
        self.__io_loop.add_callback(self._loop)

    @coroutine
    def _loop(self):
        """Execute queued operations one at a time while the connection is open."""
        log.debug("Starting queue loop")
        while self.__connected:
            # Wait until no cursor is open and no query is executing (or until closed).
            while self.__connected and (self.__has_active_cursor or self.__connection.isexecuting()):
                yield sleep(0.001)
            if not self.__connected:
                break

            func, future = yield self.__queue.get()
            if not self.__connected:
                # Request arrived after close(): fail it instead of touching a closed connection.
                future.set_exception(psycopg2.Error("Connection closed"))
                break

            try:
                result = func()
                if isinstance(result, Future):
                    result = yield result
            except Exception as exc:
                # Deliver the failure to the caller; an escaping exception would kill the loop.
                self.__io_loop.add_callback(future.set_exception, exc)
            else:
                self.__io_loop.add_callback(future.set_result, result)
            yield self.__wait()

    @coroutine
    def __wait(self):
        """Poll the connection until psycopg2 reports POLL_OK, waiting on the socket in between."""
        log.debug("Waiting for events")
        while not (yield sleep(0.001)):
            try:
                state = self.__connection.poll()
            except QueryCanceledError:
                yield sleep(0.1)
                continue

            f = Future()

            def resolve(fileno, io_op):
                if f.running():
                    f.set_result(True)
                self.__io_loop.remove_handler(fileno)

            if state == psycopg2.extensions.POLL_OK:
                raise Return(True)
            elif state == psycopg2.extensions.POLL_READ:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.READ)
                yield f
            elif state == psycopg2.extensions.POLL_WRITE:
                self.__io_loop.add_handler(self.__connection.fileno(), resolve, IOLoop.WRITE)
                yield f

    def __on_cursor_open(self, cursor):
        self.__has_active_cursor = True
        log.debug('Opening cursor')

    def __on_cursor_close(self, cursor):
        self.__has_active_cursor = False
        log.debug('Closing active cursor')

    def cursor(self, **kwargs):
        """Queue creation of an ``AsyncCursor``; returns a Future resolving to it."""
        f = Future()
        self.__io_loop.add_callback(
            self.__queue.put,
            (functools.partial(AsyncCursor, self.__connection, self.__thread_pool, self.__wait,
                               on_open=self.__on_cursor_open, on_close=self.__on_cursor_close,
                               **kwargs), f))
        return f

    def cancel(self):
        """Cancel the running query from the thread pool; returns the executor future."""
        return self.__thread_pool.submit(self.__connection.cancel)

    def close(self):
        """Stop the operation loop, fail every queued request and close the connection."""
        # Keep _loop from dispatching anything else, then let it exit.
        self.__has_active_cursor = True
        self.__closed = True
        self.__connected = False

        # Fail requests still waiting so their callers are not left hanging.
        while not self.__queue.empty():
            func, future = self.__queue.get_nowait()
            future.set_exception(psycopg2.Error("Connection closed"))

        self.__io_loop.add_callback(self.__connection.close)

    def __futurize(self, item):
        """Wrap connection method *item* so calls are queued and return a Future."""
        attr = getattr(self.__connection, item)

        @functools.wraps(attr)
        def wrap(*args, **kwargs):
            f = Future()
            self.__io_loop.add_callback(
                self.__queue.put, (functools.partial(attr, *args, **kwargs), f))
            return f

        return wrap