コード例 #1
0
ファイル: asyncio.py プロジェクト: aaronlifshin/whysaurus
    def wait(self, promise, timeout=None):
        """Hook a completion signal onto ``promise`` and call ``wait()`` on it.

        Args:
            promise: Object exposing ``_then(on_fulfilled, on_rejected)``.
            timeout: Accepted for interface compatibility only; it is not used.
        """
        settled = Event()

        def signal_settled(_):
            # The settlement value or reason is irrelevant; only completion counts.
            settled.set()

        # One callback serves both the fulfilment and the rejection path.
        promise._then(signal_settled, signal_settled)

        # The timeout cannot be applied to an asyncio Event, so it is ignored.
        # NOTE(review): if Event is asyncio.Event, wait() only creates a
        # coroutine object that is never awaited here -- confirm the intent.
        settled.wait()
コード例 #2
0
ファイル: coro.py プロジェクト: vivisect/synapse
async def event_wait(event: asyncio.Event, timeout=None):
    '''
    Wait on an asyncio event with an optional timeout.

    Args:
        event: The event to wait on.
        timeout: Seconds to wait, or None to wait indefinitely.

    Returns:
        True if the event got set, False if the timeout expired first.
    '''
    if timeout is not None:
        try:
            await asyncio.wait_for(event.wait(), timeout)
        except asyncio.TimeoutError:
            return False
        return True

    # No timeout requested: block until the event is set.
    await event.wait()
    return True
コード例 #3
0
ファイル: actor.py プロジェクト: iakinsey/illume
class Actor(object):

    """
    Main actor model.

    Repeatedly takes a message from the inbox and passes it to
    ``on_message`` until the actor is stopped. Subclasses customise the
    behaviour by overriding the ``on_*`` hooks.

    Args:
        inbox (GeneratorQueue): Inbox to consume from.
        outbox (GeneratorQueue): Outbox to publish to.
        loop: Event loop; ``get_event_loop()`` is used when falsy.
    """

    # Class-level defaults; instances overwrite them in initialize()/stop().
    running = False
    _force_stop = False

    def __init__(self, inbox, outbox, loop=None):
        self.inbox = inbox
        self.outbox = outbox

        if not loop:
            loop = get_event_loop()

        self._loop = loop
        # NOTE(review): the ``loop`` keyword of asyncio.Lock/Event was removed
        # in Python 3.10 -- confirm the interpreter range this must support.
        self._pause_lock = Lock(loop=self._loop)
        self._stop_event = Event(loop=self._loop)
        # Not read anywhere inside this class.
        self._test = None
        self.__testy = None

        self.on_init()

    @property
    def paused(self):
        """Indicate if actor is paused (i.e. the pause lock is held)."""
        return self._pause_lock.locked()

    async def start(self):
        """Main public entry point to start the actor."""
        await self.initialize()
        await self._start()

    async def initialize(self):
        """Initialize the actor before starting.

        Runs ``on_start`` and then marks the actor as running, unless
        ``stop`` was already requested.
        """
        await self.on_start()

        if self._force_stop:
            return

        if not self.running:
            self.running = True

    async def _start(self):
        """Run the event loop and force the on_stop event."""
        try:
            await self._run()
        finally:
            # on_stop runs even when the main loop exits with an exception.
            await self.on_stop()

    async def resume(self):
        """Resume the actor by running ``on_resume`` and freeing the pause lock."""
        await self.on_resume()
        self._pause_lock.release()

    async def pause(self):
        """Pause the actor by taking the pause lock, then run ``on_pause``."""
        await self._pause_lock.acquire()
        await self.on_pause()

    async def _block_if_paused(self):
        """Block on the pause lock until the actor is resumed."""
        if self.paused:
            await self._pause_lock.acquire()
            # Lock.release() is a plain method returning None; awaiting its
            # result would raise TypeError, so it is called without ``await``.
            self._pause_lock.release()

    async def _run(self):
        """Main event loop: process messages for as long as the actor runs."""
        while self.running:
            await self._block_if_paused()

            await self._process()

    async def publish(self, data):
        """Push data to the outbox."""
        await self.outbox.put(data)

    async def stop(self):
        """Stop the actor and wake up anything waiting on it."""
        self.inbox = None
        self.outbox = None
        self.running = False
        self._force_stop = True

        self._stop_event.set()

        try:
            self._pause_lock.release()
        except RuntimeError:
            # The lock was not held (the actor was not paused); nothing to free.
            pass

    async def _process(self):
        """Process incoming messages.

        Waits for whichever comes first: the next inbox item or the stop
        event. The result is only dispatched while the actor is running.
        """
        if not self.inbox:
            return

        pending = {self.inbox.get(), self._stop_event.wait()}
        result = await get_first_completed(pending, self._loop)

        if self.running:
            await self.on_message(result)

    async def on_message(self, data):
        """Called when the actor receives a message. Must be overridden."""
        raise NotImplementedError

    def on_init(self):
        """Called after the actor class is instantiated."""
        pass

    async def on_start(self):
        """Called before the actor starts ingesting the inbox."""
        pass

    async def on_stop(self):
        """Called after actor dies."""
        pass

    async def on_pause(self):
        """Called before the actor is paused."""
        pass

    async def on_resume(self):
        """Called before the actor is resumed."""
        pass
コード例 #4
0
class SubscribeListener(SubscribeCallback):
    """Subscribe callback that exposes incoming events as awaitables.

    Status changes set asyncio events, while messages, presence events and
    errors are pushed onto asyncio queues that the ``wait_for_*``
    coroutines consume.
    """

    def __init__(self):
        self.connected = False
        self.connected_event = Event()
        self.disconnected_event = Event()
        self.presence_queue = Queue()
        self.message_queue = Queue()
        self.error_queue = Queue()

    def status(self, pubnub, status):
        """Record a status callback as a connect/disconnect event or an error."""
        if utils.is_subscribed_event(status) and not self.connected_event.is_set():
            self.connected_event.set()
        elif utils.is_unsubscribed_event(status) and not self.disconnected_event.is_set():
            self.disconnected_event.set()
        elif status.is_error():
            self.error_queue.put_nowait(status.error_data.exception)

    def message(self, pubnub, message):
        """Queue an incoming message."""
        self.message_queue.put_nowait(message)

    def presence(self, pubnub, presence):
        """Queue an incoming presence event."""
        self.presence_queue.put_nowait(presence)

    @asyncio.coroutine
    def _wait_for(self, coro):
        """Wait for ``coro`` or for a queued error, whichever finishes first.

        Returns the result of ``coro``. If only the error arrived, ``coro``
        is cancelled and the queued exception is raised instead.
        """
        scc_task = asyncio.ensure_future(coro)
        err_task = asyncio.ensure_future(self.error_queue.get())

        yield from asyncio.wait([
            scc_task,
            err_task
        ], return_when=asyncio.FIRST_COMPLETED)

        if err_task.done() and not scc_task.done():
            if not scc_task.cancelled():
                scc_task.cancel()
            raise err_task.result()
        else:
            if not err_task.cancelled():
                err_task.cancel()
            return scc_task.result()

    @asyncio.coroutine
    def wait_for_connect(self):
        """Wait until the connected event is set.

        Raises:
            Exception: if the connected event is already set.
        """
        if not self.connected_event.is_set():
            yield from self._wait_for(self.connected_event.wait())
        else:
            raise Exception("instance is already connected")

    @asyncio.coroutine
    def wait_for_disconnect(self):
        """Wait until the disconnected event is set.

        Raises:
            Exception: if the disconnected event is already set.
        """
        if not self.disconnected_event.is_set():
            yield from self._wait_for(self.disconnected_event.wait())
        else:
            raise Exception("instance is already disconnected")

    @asyncio.coroutine
    def wait_for_message_on(self, *channel_names):
        """Return the next queued message whose channel is in ``channel_names``.

        Messages for other channels are consumed and dropped.
        """
        channel_names = list(channel_names)
        while True:
            env = yield from self._wait_for(self.message_queue.get())
            # Acknowledge only after an item was really dequeued. Calling
            # task_done() when _wait_for raised (queued error, cancellation)
            # would miscount the queue or raise ValueError over the real error.
            self.message_queue.task_done()
            if env.channel in channel_names:
                return env

    @asyncio.coroutine
    def wait_for_presence_on(self, *channel_names):
        """Return the next queued presence event whose channel is in ``channel_names``.

        Presence events for other channels are consumed and dropped.
        """
        channel_names = list(channel_names)
        while True:
            env = yield from self._wait_for(self.presence_queue.get())
            # Acknowledge only after an item was really dequeued (see
            # wait_for_message_on for the failure mode this avoids).
            self.presence_queue.task_done()
            if env.channel in channel_names:
                return env
コード例 #5
0
ファイル: session.py プロジェクト: WebDevCaptain/lbry-sdk
class SessionBase(asyncio.Protocol):
    """Common base for networking sessions.

    Client and server sessions differ only in which side initiated the
    connection.

    To connect to a remote server, pass host, port and proxy to the
    constructor and then call create_connection().  Every successful
    call should be paired with a call to close().

    When used in a with statement the connection is made on entering
    the block and closed on leaving it.
    """

    max_errors = 10

    def __init__(self, *, framer=None, loop=None):
        self.framer = framer or self.default_framer()
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.transport = None
        # Filled in once a connection is made
        self._address = None
        self._proxy_address = None
        # Controls logger.debug output
        self.verbosity = 0
        # Set while sending is allowed; cleared when the send socket is full
        can_send = Event()
        can_send.set()
        self._can_send = can_send
        self._pm_task = None
        self._task_group = TaskGroup(self.loop)
        # Force-close a connection if a send doesn't succeed in this time
        self.max_send_delay = 60
        # Statistics.  The RPC object also keeps its own statistics.
        self.start_time = time.perf_counter()
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = self.start_time
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = self.start_time
        self.last_packet_received = self.start_time

    async def _limited_wait(self, secs):
        """Wait up to ``secs`` seconds for sending to be allowed again.

        On timeout the connection is aborted and asyncio.TimeoutError is
        raised.
        """
        waiter = self._can_send.wait()
        try:
            await asyncio.wait_for(waiter, secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.TimeoutError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        """Frame ``message`` and write it to the transport, updating stats."""
        if not self._can_send.is_set():
            await self._limited_wait(self.max_send_delay)
        if self.is_closing():
            # Nothing is written once the connection is going away.
            return
        framed_message = self.framer.frame(message)
        self.send_size += len(framed_message)
        self.send_count += 1
        self.last_send = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Sending framed message {framed_message}')
        self.transport.write(framed_message)

    def _bump_errors(self):
        """Count an error and close the transport at ``max_errors``."""
        self.errors += 1
        if self.errors < self.max_errors:
            return
        # Don't await self.close() because that is self-cancelling
        self._close()

    def _close(self):
        """Close the transport if there is one."""
        transport = self.transport
        if transport:
            transport.close()

    # asyncio framework
    def data_received(self, framed_message):
        """Called by asyncio when a message comes in."""
        self.last_packet_received = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
        self.framer.received_bytes(framed_message)

    def pause_writing(self):
        """Transport calls when the send buffer is full."""
        if self.is_closing():
            return
        self._can_send.clear()
        self.transport.pause_reading()

    def resume_writing(self):
        """Transport calls when the send buffer has room."""
        if self._can_send.is_set():
            return
        self._can_send.set()
        self.transport.resume_reading()

    def connection_made(self, transport):
        """Called by asyncio when a connection is established.

        Derived classes overriding this method must call this first."""
        self.transport = transport
        # This would throw if called on a closed SSL transport.  Fixed
        # in asyncio in Python 3.6.1 and 3.5.4
        peer_address = transport.get_extra_info('peername')
        # When the Socks proxy was used _address already holds the
        # remote address, so the peer is the proxy.
        if not self._address:
            self._address = peer_address
        else:
            self._proxy_address = peer_address
        self._pm_task = self.loop.create_task(self._receive_messages())

    def connection_lost(self, exc):
        """Called by asyncio when the connection closes.

        Tear down things done in connection_made."""
        self._address = None
        self.transport = None
        self._task_group.cancel()
        pm_task = self._pm_task
        if pm_task:
            pm_task.cancel()
        # Release waiting tasks
        self._can_send.set()

    # External API
    def default_framer(self):
        """Return a default framer."""
        raise NotImplementedError

    def peer_address(self):
        """Returns the peer's address (Python networking address), or None if
        no connection or an error.

        This is the result of socket.getpeername() when the connection
        was made.
        """
        return self._address

    def peer_address_str(self):
        """Returns the peer's IP address and port as a human-readable
        string; addresses containing ':' are wrapped in brackets."""
        if not self._address:
            return 'unknown'
        ip_addr_str, port = self._address[:2]
        if ':' not in ip_addr_str:
            return f'{ip_addr_str}:{port}'
        return f'[{ip_addr_str}]:{port}'

    def is_closing(self):
        """Return True if the connection is closing."""
        if not self.transport:
            return True
        return self.transport.is_closing()

    def abort(self):
        """Forcefully close the connection."""
        transport = self.transport
        if transport:
            transport.abort()

    # TODO: replace with synchronous_close
    async def close(self, *, force_after=30):
        """Close the connection and return when closed."""
        self._close()
        if not self._pm_task:
            return
        with suppress(CancelledError):
            # Give the message task force_after seconds, then abort and
            # wait for it to finish.
            await asyncio.wait([self._pm_task], timeout=force_after)
            self.abort()
            await self._pm_task

    def synchronous_close(self):
        """Close the transport and cancel the message task without awaiting."""
        self._close()
        pm_task = self._pm_task
        if pm_task and not pm_task.done():
            pm_task.cancel()
コード例 #6
0
ファイル: async_pool.py プロジェクト: artemmus/mysql_executor
class AsyncConnectionPool:
    """Object manages asynchronous connections.

    :param int size: size (number of connection) of the pool.
    :param float queue_timeout: time out when client is waiting connection
        from pool
    :param loop: event loop, if not passed then default will be used
    :param config: MySql connection config see
        `doc. <http://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html>`_
    :raise ValueError: if the `size` is inappropriate
    """
    def __init__(self, size=1, queue_timeout=15.0, *, loop=None, **config):
        # Validate with an exception rather than ``assert``: an assert is
        # stripped under ``python -O`` and would raise AssertionError
        # instead of the documented ValueError.
        if size < 1:
            raise ValueError('DBPool.size is less than 1, '
                             'connections won"t be established')
        self._pool = set()
        self._busy_items = set()
        self._size = size
        # Futures of callers waiting for a connection, oldest first.
        self._pending_futures = deque()
        self._queue_timeout = queue_timeout
        self._loop = loop or asyncio.get_event_loop()
        self.config = config

        # Set while the pool is usable; cleared for the duration of shutdown().
        self._shutdown_event = Event(loop=self._loop)
        self._shutdown_event.set()

    @property
    def queue_timeout(self):
        """Number of seconds to wait a connection from the pool,
        before TimeoutError occurred

        :rtype: float
        """
        return self._queue_timeout

    @queue_timeout.setter
    def queue_timeout(self, value):
        """Sets a timeout for :attr:`queue_timeout`

        :param float value: number of seconds
        :raise ValueError: if `value` is neither a float nor an int
        """
        if not isinstance(value, (float, int)):
            raise ValueError('Float or integer type expected')
        self._queue_timeout = value

    @property
    def size(self):
        """Size of pool

        :rtype: int
        """
        return self._size

    def __len__(self):
        """Number of allocated pool's slots

        :rtype: int
        """
        return len(self._pool)

    @property
    def free_count(self):
        """Number of free pool's slots

        :rtype: int
        """
        return self.size - len(self._busy_items)

    @asyncio.coroutine
    def get(self):
        """Coroutine. Returns an opened connection from pool.
        If coroutine invoked when all connections have been issued, then
        caller will blocked until some connection will be released.

        Also, the class provides context manager for getting connection
        and automatically freeing it. Example:
        >>> with (yield from pool) as cnx:
        >>>     ...

        :rtype: AsyncMySQLConnection
        :raise: concurrent.futures.TimeoutError()
        """
        cnx = None

        # Do not hand out connections while shutdown() is in progress.
        yield from self._shutdown_event.wait()

        for free_client in self._pool - self._busy_items:
            # Reuse an already opened idle connection.
            cnx = free_client
            self._busy_items.add(cnx)
            break
        else:
            if len(self) < self.size:
                # Room left in the pool: open a new connection.
                cnx = AsyncMySQLConnection(loop=self._loop)
                self._pool.add(cnx)
                self._busy_items.add(cnx)

                try:
                    yield from cnx.connect(**self.config)
                except BaseException:
                    # Roll back the reservation on any failure, including
                    # cancellation, then propagate it unchanged.
                    self._pool.remove(cnx)
                    self._busy_items.remove(cnx)
                    raise

        if not cnx:
            # Pool exhausted: queue up until release() hands over a connection.
            queue_future = Future(loop=self._loop)
            self._pending_futures.append(queue_future)
            try:
                cnx = yield from asyncio.wait_for(queue_future,
                                                  self.queue_timeout,
                                                  loop=self._loop)
                self._busy_items.add(cnx)
            except TimeoutError:
                raise TimeoutError('Database pool is busy')
            finally:
                try:
                    self._pending_futures.remove(queue_future)
                except ValueError:
                    # release() or shutdown() already removed the future.
                    pass

        return cnx

    def release(self, connection):
        """Frees connection. After that the connection can be issued
        by :func:`get`.

        :param AsyncMySQLConnection connection: a connection received
            from :func:`get`
        """
        # Pass the connection straight to the oldest waiter that can still
        # accept it. Waiters whose future is already done (timed out or
        # cancelled) are skipped: set_result() on them would raise
        # InvalidStateError and the connection would stay marked busy.
        while self._pending_futures:
            waiter = self._pending_futures.popleft()
            if not waiter.done():
                waiter.set_result(connection)
                return
        self._busy_items.remove(connection)

    @asyncio.coroutine
    def shutdown(self):
        """Coroutine. Closes all connections and purge queue of a waiting
        for connection.
        """
        self._shutdown_event.clear()
        try:
            for cnx in self._pool:
                yield from cnx.disconnect()

            for f in self._pending_futures:
                f.cancel()

            self._pending_futures.clear()
            self._pool = set()
            self._busy_items = set()
        finally:
            # Always unblock get() again, even if a disconnect failed.
            self._shutdown_event.set()

    def __enter__(self):
        raise RuntimeError(
            '"yield from" should be used as context manager expression')

    def __exit__(self, *args):
        # This must exist because __enter__ exists, even though that
        # always raises; that's how the with-statement works.
        pass

    @asyncio.coroutine
    def __iter__(self):
        # Supports ``with (yield from pool) as cnx``.
        cnx = yield from self.get()
        return ContextManager(self, cnx)
コード例 #7
0
ファイル: worker.py プロジェクト: LubinLew/wazuh
    async def process_files_from_master(self, name: str, file_received: asyncio.Event):
        """Apply the integrity information sent by the master to this worker.

        Waits for the integrity file, decompresses it, updates or removes local
        files according to their status and sends the master any extra_valid
        files it asked for.

        Parameters
        ----------
        name : str
            Task ID that was waiting for the file to be received.
        file_received : asyncio.Event
            Asyncio event that is unlocked once the file has been received.
        """
        logger = self.task_loggers['Integrity sync']
        receive_timeout = self.cluster_items['intervals']['communication']['timeout_receiving_file']

        try:
            await asyncio.wait_for(file_received.wait(), timeout=receive_timeout)
        except Exception:
            # Any failure while waiting is reported to the master as error 3039.
            timeout_exc = WazuhClusterError(3039, extra_message=f'Integrity sync at {self.name}')
            await self.send_request(
                command=b'syn_i_w_m_r',
                data=b'None ' + json.dumps(timeout_exc, cls=c_common.WazuhJSONEncoder).encode())
            raise timeout_exc

        # The task stores an exception in place of the filename when it failed.
        if isinstance(self.sync_tasks[name].filename, Exception):
            wrapped_error = exception.WazuhClusterError(
                1000, extra_message=str(self.sync_tasks[name].filename))
            exc_info = json.dumps(wrapped_error, cls=c_common.WazuhJSONEncoder)
            await self.send_request(command=b'syn_i_w_m_r', data=b'None ' + exc_info.encode())
            raise self.sync_tasks[name].filename

        zip_path = ""
        # Path of the zip containing a JSON with metadata and files to be updated in this worker node.
        received_filename = self.sync_tasks[name].filename

        try:
            self.integrity_sync_status['date_start'] = datetime.utcnow().timestamp()
            logger.info("Starting.")

            # - zip_path contains the path of the unzipped directory
            # - ko_files contains a Dict with this structure:
            #   {'missing': {'<file_path>': {<MD5, merged, merged_name, etc>}, ...},
            #    'shared': {...}, 'extra': {...}, 'extra_valid': {...}}
            ko_files, zip_path = await cluster.run_in_pool(self.loop, self.manager.task_pool, cluster.decompress_files,
                                                           received_filename)
            logger.info("Files to create: {} | Files to update: {} | Files to delete: {} | Files to send: {}".format(
                len(ko_files['missing']), len(ko_files['shared']), len(ko_files['extra']), len(ko_files['extra_valid']))
            )

            if ko_files['shared'] or ko_files['missing'] or ko_files['extra']:
                # Bring local files in line with the master (missing, extra or shared).
                logger.debug("Worker does not meet integrity checks. Actions required.")
                logger.debug("Updating local files: Start.")
                await cluster.run_in_pool(self.loop, self.manager.task_pool, self.update_master_files_in_worker,
                                          ko_files, zip_path, self.cluster_items, self.task_loggers['Integrity sync'])
                logger.debug("Updating local files: End.")

            if not ko_files['extra_valid']:
                logger.info(
                    f"Finished in {datetime.utcnow().timestamp() - self.integrity_sync_status['date_start']:.3f}s.")
            else:
                # The master asked for some of this worker's files; send them in a separate task.
                logger.debug("Master requires some worker files.")
                asyncio.create_task(self.sync_extra_valid(ko_files['extra_valid']))

        except exception.WazuhException as e:
            logger.error(f"Error synchronizing extra valid files: {e}")
            await self.send_request(command=b'syn_i_w_m_r',
                                    data=b'None ' + json.dumps(e, cls=c_common.WazuhJSONEncoder).encode())
        except Exception as e:
            # Unexpected errors are wrapped in a generic cluster error (1000) before reporting.
            logger.error(f"Error synchronizing extra valid files: {e}")
            exc_info = json.dumps(exception.WazuhClusterError(1000, extra_message=str(e)),
                                  cls=c_common.WazuhJSONEncoder)
            await self.send_request(command=b'syn_i_w_m_r', data=b'None ' + exc_info.encode())
        finally:
            # Remove the unzipped directory, if one was created.
            if zip_path:
                shutil.rmtree(zip_path)
コード例 #8
0
class SubscribeListener(SubscribeCallback):
    """Subscribe callback that exposes incoming events as awaitables.

    Status changes set asyncio events, while messages, presence events and
    errors are pushed onto asyncio queues that the ``wait_for_*``
    coroutines consume.
    """

    def __init__(self):
        self.connected = False
        self.connected_event = Event()
        self.disconnected_event = Event()
        self.presence_queue = Queue()
        self.message_queue = Queue()
        self.error_queue = Queue()

    def status(self, pubnub, status):
        """Record a status callback as a connect/disconnect event or an error."""
        if utils.is_subscribed_event(
                status) and not self.connected_event.is_set():
            self.connected_event.set()
        elif utils.is_unsubscribed_event(
                status) and not self.disconnected_event.is_set():
            self.disconnected_event.set()
        elif status.is_error():
            self.error_queue.put_nowait(status.error_data.exception)

    def message(self, pubnub, message):
        """Queue an incoming message."""
        self.message_queue.put_nowait(message)

    def presence(self, pubnub, presence):
        """Queue an incoming presence event."""
        self.presence_queue.put_nowait(presence)

    async def _wait_for(self, coro):
        """Wait for ``coro`` or for a queued error, whichever finishes first.

        Returns the result of ``coro``. If only the error arrived, ``coro``
        is cancelled and the queued exception is raised instead.
        """
        scc_task = asyncio.ensure_future(coro)
        err_task = asyncio.ensure_future(self.error_queue.get())

        await asyncio.wait([scc_task, err_task],
                           return_when=asyncio.FIRST_COMPLETED)

        if err_task.done() and not scc_task.done():
            if not scc_task.cancelled():
                scc_task.cancel()
            raise err_task.result()
        else:
            if not err_task.cancelled():
                err_task.cancel()
            return scc_task.result()

    async def wait_for_connect(self):
        """Wait until the connected event is set.

        Raises:
            Exception: if the connected event is already set.
        """
        if not self.connected_event.is_set():
            await self._wait_for(self.connected_event.wait())
        else:
            raise Exception("instance is already connected")

    async def wait_for_disconnect(self):
        """Wait until the disconnected event is set.

        Raises:
            Exception: if the disconnected event is already set.
        """
        if not self.disconnected_event.is_set():
            await self._wait_for(self.disconnected_event.wait())
        else:
            raise Exception("instance is already disconnected")

    async def wait_for_message_on(self, *channel_names):
        """Return the next queued message whose channel is in ``channel_names``.

        Messages for other channels are consumed and dropped.
        """
        channel_names = list(channel_names)
        while True:
            env = await self._wait_for(self.message_queue.get())
            # Acknowledge only after an item was really dequeued. Calling
            # task_done() when _wait_for raised (queued error, cancellation)
            # would miscount the queue or raise ValueError over the real error.
            self.message_queue.task_done()
            if env.channel in channel_names:
                return env

    async def wait_for_presence_on(self, *channel_names):
        """Return the next queued presence event whose channel is in ``channel_names``.

        Presence events for other channels are consumed and dropped.
        """
        channel_names = list(channel_names)
        while True:
            env = await self._wait_for(self.presence_queue.get())
            # Acknowledge only after an item was really dequeued (see
            # wait_for_message_on for the failure mode this avoids).
            self.presence_queue.task_done()
            if env.channel in channel_names:
                return env
コード例 #9
0
    async def connect(self, addr, port, timeout=3):
        """Open a new stream to ``addr:port`` over this hxsocks2 connection.

        Args:
            addr (str): Remote host to connect to.
            port (int): Remote port, packed as an unsigned 16-bit integer.
            timeout: Seconds allowed for the key exchange and for the
                server's reply to the connect request.

        Returns:
            One end of a local socket pair whose other end is forwarded to
            the new stream.

        Raises:
            ConnectionResetError: the connection is lost, the key exchange
                failed, or the remote connect was not reported as open.
            asyncio.TimeoutError: no reply arrived within ``timeout``.
        """
        self.logger.debug('hxsocks2 send connect request')
        async with self._lock:
            if self.connection_lost:
                self._manager.remove(self)
                raise ConnectionResetError(0, 'hxs connection lost')
            if not self.connected:
                try:
                    await self.get_key(timeout)
                except asyncio.CancelledError:
                    raise
                except Exception as err:
                    self.logger.error('%s get_key %r', self.name, err)
                    try:
                        self.remote_writer.close()
                    except (OSError, AttributeError):
                        # Best effort: the writer may be missing or already broken.
                        pass
                    raise ConnectionResetError(0, 'hxs get_key failed.') from err
        # send connect request
        # The one-byte length prefix must describe the encoded bytes that
        # follow it, not the number of characters in the str.
        addr_bytes = addr.encode()
        payload = b''.join([
            bytes((len(addr_bytes),)),
            addr_bytes,
            struct.pack('>H', port),
            # Random-length zero padding.
            b'\x00' * random.randint(64, 255),
        ])
        stream_id = self._next_stream_id
        self._next_stream_id += 1
        if self._next_stream_id > MAX_STREAM_ID:
            # This request still goes out; the connection is only withdrawn
            # from the manager.
            self.logger.error('MAX_STREAM_ID reached')
            self._manager.remove(self)

        await self.send_frame(1, 0, stream_id, payload)

        # wait for server response
        event = Event()
        self._client_status[stream_id] = event

        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            self.logger.error('no response from %s, timeout=%.3f', self.name,
                              timeout)
            self._client_status.pop(stream_id, None)
            self.print_status()
            await self.send_ping()
            raise
        finally:
            # Always drop the registration, including on cancellation, so
            # abandoned streams do not pile up in _client_status.
            self._client_status.pop(stream_id, None)

        if self._stream_status[stream_id] == OPEN:
            socketpair_a, socketpair_b = socket.socketpair()
            if sys.platform == 'win32':
                socketpair_a.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF,
                                        65536)
                socketpair_b.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF,
                                        65536)

            reader, writer = await asyncio.open_connection(sock=socketpair_b)
            writer.transport.set_write_buffer_limits(0, 0)

            self._client_writer[stream_id] = writer
            self._last_active[stream_id] = time.monotonic()
            # start forwarding
            asyncio.ensure_future(self.read_from_client(stream_id, reader))
            return socketpair_a
        raise ConnectionResetError(
            0, 'remote connect to %s:%d failed.' % (addr, port))
コード例 #10
0
ファイル: channel.py プロジェクト: tbug/aiochannel
class Channel(object):
    """A closable queue for asyncio code.

    A Channel is considered "finished" when it is closed and drained (unlike
    a queue, which is "finished" when the queue is empty).

    Args:
        maxsize: Maximum number of buffered items. 0 (the default) means
            ``full()`` is never True.
        loop: Event loop used for the internal futures and events. Defaults
            to ``get_event_loop()``.

    Raises:
        TypeError: If ``maxsize`` is not an integer >= 0.
    """

    def __init__(self, maxsize=0, *, loop=None):
        if loop is None:
            self._loop = get_event_loop()
        else:
            self._loop = loop

        if not isinstance(maxsize, int) or maxsize < 0:
            raise TypeError("maxsize must be an integer >= 0 (default is 0)")
        self._maxsize = maxsize

        # Futures of coroutines currently blocked in get() / put().
        self._getters = deque()
        self._putters = deque()

        # "finished" means channel is closed and drained
        self._finished = Event(loop=self._loop)
        self._close = Event(loop=self._loop)

        self._init()

    def _init(self):
        """Create the internal item buffer."""
        self._queue = deque()

    def _get(self):
        """Pop the oldest buffered item."""
        return self._queue.popleft()

    def _put(self, item):
        """Append an item to the buffer."""
        self._queue.append(item)

    def _wakeup_next(self, waiters):
        """Resolve the first waiter in ``waiters`` that is not done yet."""
        # Waiters that were cancelled stay in the deque, so skip over them.
        while waiters:
            waiter = waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)
                break

    def __repr__(self):
        return '<{} at {:#x} maxsize={!r} qsize={!r}>'.format(
            type(self).__name__, id(self), self._maxsize, self.qsize())

    def __str__(self):
        return '<{} maxsize={!r} qsize={!r}>'.format(
            type(self).__name__, self._maxsize, self.qsize())

    def qsize(self):
        """Number of items in the channel buffer."""
        return len(self._queue)

    @property
    def maxsize(self):
        """Number of items allowed in the channel buffer."""
        return self._maxsize

    def empty(self):
        """Return True if the channel is empty, False otherwise."""
        return not self._queue

    def full(self):
        """Return True if there are maxsize items in the channel.

        Note: if the Channel was initialized with maxsize=0 (the default),
        then full() is never True.
        """
        if self._maxsize <= 0:
            return False
        return self.qsize() >= self._maxsize

    @coroutine
    def put(self, item):
        """Put an item into the channel.

        If the channel is full, wait until a free slot is available before
        adding item.
        If the channel is closed or closing, raise ChannelClosed.
        This method is a coroutine.
        """
        while self.full() and not self._close.is_set():
            putter = Future(loop=self._loop)
            self._putters.append(putter)
            try:
                yield from putter
            except ChannelClosed:
                raise
            except BaseException:
                putter.cancel()  # Just in case putter is not done yet.
                if not self.full() and not putter.cancelled():
                    # We were woken up by get_nowait(), but can't take
                    # the call.  Wake up the next in line.
                    self._wakeup_next(self._putters)
                raise
        if self._close.is_set():
            # A closed channel that is still full must report ChannelClosed
            # (as documented) rather than the ChannelFull that put_nowait()
            # would raise first.
            raise ChannelClosed
        return self.put_nowait(item)

    def put_nowait(self, item):
        """Put an item into the channel without blocking.

        If no free slot is immediately available, raise ChannelFull.
        If the channel is closed, raise ChannelClosed.
        """
        if self.full():
            raise ChannelFull
        if self._close.is_set():
            raise ChannelClosed
        self._put(item)
        self._wakeup_next(self._getters)

    @coroutine
    def get(self):
        """Remove and return an item from the channel.

        If channel is empty, wait until an item is available.
        If the channel is closed and empty, raise ChannelClosed.
        This method is a coroutine.
        """
        while self.empty() and not self._close.is_set():
            getter = Future(loop=self._loop)
            self._getters.append(getter)
            try:
                yield from getter
            except ChannelClosed:
                raise
            except BaseException:
                getter.cancel()  # Just in case getter is not done yet.
                if not self.empty() and not getter.cancelled():
                    # We were woken up by put_nowait(), but can't take
                    # the call.  Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise
        return self.get_nowait()

    def get_nowait(self):
        """Remove and return an item from the channel.

        Return an item if one is immediately available. Otherwise raise
        ChannelClosed when the channel is closed, else ChannelEmpty.
        """
        if self.empty():
            if self._close.is_set():
                raise ChannelClosed
            else:
                raise ChannelEmpty
        item = self._get()
        if self.empty() and self._close.is_set():
            # if empty _after_ we retrieved an item AND marked for closing,
            # set the finished flag
            self._finished.set()
        self._wakeup_next(self._putters)
        return item

    @coroutine
    def join(self):
        """Block until channel is closed and channel is drained."""
        yield from self._finished.wait()

    def close(self):
        """Mark the channel as closed and fail the waiters that cannot proceed.

        Every pending putter receives ChannelClosed, and so does every getter
        beyond the number of buffered items. Calling close() more than once
        is harmless.
        """
        self._close.set()
        # Fail pending putters. Cancelled (already done) futures are left in
        # the deque by put(), and set_exception() on a done future raises
        # InvalidStateError, so skip them. Draining the deque also keeps a
        # repeated close() from touching the same futures again.
        while self._putters:
            putter = self._putters.popleft()
            if not putter.done():
                putter.set_exception(ChannelClosed())
        # cancel getters that can't ever return (as no more items can be added)
        while len(self._getters) > self.qsize():
            getter = self._getters.pop()
            if not getter.done():
                getter.set_exception(ChannelClosed())

        if self.empty():
            # already empty, mark as finished
            self._finished.set()

    def closed(self):
        """Returns True if the Channel is marked as closed"""
        return self._close.is_set()

    def __aiter__(self):  # pragma: no cover
        """Return the async iterator (self).

        This must be a plain method: ``async for`` requires __aiter__ to
        return the iterator object itself, not an awaitable.
        """
        return self

    @coroutine
    def __anext__(self):  # pragma: no cover
        """Return the next item, ending the iteration once the channel is closed."""
        try:
            data = yield from self.get()
        except ChannelClosed:
            raise StopAsyncIteration
        else:
            return data

    def __iter__(self):
        """Iterate over the currently buffered items without consuming them."""
        return iter(self._queue)
コード例 #11
0
ファイル: master.py プロジェクト: jaydave/wazuh
    async def sync_integrity(self, task_id: str, received_file: asyncio.Event):
        """Compare the worker's integrity metadata with the local one and send back what differs.

        The coroutine waits (with a timeout) for `received_file` to be set, decompresses the file registered
        under `task_id` and compares its metadata with `self.server.integrity_control`. The comparison result
        is read through the keys 'shared', 'missing', 'extra' and 'extra_valid'.

        When no file requires a change, a `syn_m_c_ok` request is sent. Otherwise the 'shared' and 'missing'
        files are compressed and sent through the `syn_m_c` -> `send_file` -> `syn_m_c_e` sequence. If that
        sequence fails, a `syn_m_c_r` request carrying the serialized error is sent instead. The local
        compressed file is always removed afterwards.

        Parameters
        ----------
        task_id : str
            Key used to look up the received file in `self.sync_tasks`. The name is reused later for the
            reply of the `syn_m_c` request.
        received_file : asyncio.Event
            Event awaited before the received file is processed.

        Returns
        -------
        The value returned by the last `send_request` call awaited in this method.

        Raises
        ------
        asyncio.TimeoutError
            If `received_file` is not set within the configured 'timeout_receiving_file' interval.
        Exception
            If the filename stored for `task_id` is itself an exception instance, it is re-raised.
        """
        logger = self.task_loggers['Integrity check']
        date_start_master = datetime.now()

        logger.debug("Waiting to receive zip file from worker.")
        await asyncio.wait_for(received_file.wait(),
                               timeout=self.cluster_items['intervals']['communication']['timeout_receiving_file'])

        # Full path where the zip sent by the worker is located.
        # The slot may hold an exception instead of a path; in that case it is propagated to the caller.
        received_filename = self.sync_tasks[task_id].filename
        if isinstance(received_filename, Exception):
            raise received_filename

        logger.debug(f"Received file from worker: '{received_filename}'")

        # Path to metadata file (files_metadata.json) and to zipdir (directory with decompressed files).
        files_metadata, decompressed_files_path = await wazuh.core.cluster.cluster.decompress_files(received_filename)
        # There are no files inside decompressed_files_path, only files_metadata.json which has already been loaded.
        shutil.rmtree(decompressed_files_path)
        logger.info(f"Starting. Received metadata of {len(files_metadata)} files.")

        # Classify files in shared, missing, extra and extra valid.
        worker_files_ko, counts = wazuh.core.cluster.cluster.compare_files(self.server.integrity_control,
                                                                           files_metadata, self.name)

        self.integrity_check_status.update({'date_start_master': date_start_master, 'date_end_master': datetime.now()})
        total_time = (self.integrity_check_status['date_end_master'] - date_start_master).total_seconds()

        # Get the total number of files that require some change (sum of the sizes of every group).
        if not functools.reduce(operator.add, map(len, worker_files_ko.values())):
            logger.info(f"Finished in {total_time:.3f}s. Sync not required.")
            result = await self.send_request(command=b'syn_m_c_ok', data=b'')
        else:
            logger.info(f"Finished in {total_time:.3f}s. Sync required.")

            # From here on messages go to the 'Integrity sync' task logger.
            logger = self.task_loggers['Integrity sync']
            logger.info("Starting.")
            self.integrity_sync_status.update({'tmp_date_start_master': datetime.now(), 'total_files': counts,
                                               'total_extra_valid': 0})
            logger.info("Files to create in worker: {} | Files to update in worker: {} | Files to delete in worker: {} "
                        "| Files to receive: {}".format(len(worker_files_ko['missing']), len(worker_files_ko['shared']),
                                                        len(worker_files_ko['extra']), len(worker_files_ko['extra_valid']))
                        )

            # Compress data: master files (only KO shared and missing).
            logger.debug("Compressing files to be synced in worker.")
            master_files_paths = worker_files_ko['shared'].keys() | worker_files_ko['missing'].keys()
            compressed_data = wazuh.core.cluster.cluster.compress_files(self.name, master_files_paths, worker_files_ko)

            logger.debug("Zip with files to be synced sent to worker.")
            try:
                # Start the synchronization process with the worker and get a taskID.
                # NOTE(review): `task_id` is rebound to the reply here; if this request raises instead, the
                # handlers below concatenate the caller-supplied value (annotated `str`) with bytes - confirm
                # the real type of the argument.
                task_id = await self.send_request(command=b'syn_m_c', data=b'')
                if isinstance(task_id, Exception) or task_id.startswith(b'Error'):
                    exc_info = task_id if isinstance(task_id, Exception) else \
                        exception.WazuhClusterError(code=3016, extra_message=str(task_id))
                    # Placeholder ID so the error handlers can still build their payload.
                    task_id = b'None'
                    raise exc_info

                # Send zip file to the worker into chunks.
                await self.send_file(compressed_data)

                # Finish the synchronization process and notify where the file corresponding to the taskID is located.
                result = await self.send_request(command=b'syn_m_c_e',
                                                 data=task_id + b' ' + os.path.relpath(
                                                     compressed_data, common.wazuh_path).encode())
                if isinstance(result, Exception):
                    raise result
                elif result.startswith(b'Error'):
                    raise exception.WazuhClusterError(3016, extra_message=result.decode())
            except exception.WazuhException as e:
                # Notify error to worker and delete its received file.
                self.logger.error(f"Error sending files information: {e}")
                result = await self.send_request(command=b'syn_m_c_r', data=task_id + b' ' +
                                                 json.dumps(e, cls=c_common.WazuhJSONEncoder).encode())
            except Exception as e:
                # Notify error to worker and delete its received file.
                # Any other error is wrapped in a WazuhClusterError (code 1000) before being serialized.
                self.logger.error(f"Error sending files information: {e}")
                exc_info = json.dumps(exception.WazuhClusterError(code=1000, extra_message=str(e)),
                                      cls=c_common.WazuhJSONEncoder).encode()
                result = await self.send_request(command=b'syn_m_c_r', data=task_id + b' ' + exc_info)
            finally:
                # Remove local file.
                os.unlink(compressed_data)
                logger.debug("Finished sending files to worker.")
                # Log 'Finished in' message only if there are no extra_valid files to sync.
                if not worker_files_ko['extra_valid']:
                    self.integrity_sync_status['date_start_master'] = self.integrity_sync_status['tmp_date_start_master']
                    self.integrity_sync_status['date_end_master'] = datetime.now()
                    logger.info("Finished in {:.3f}s.".format((self.integrity_sync_status['date_end_master'] -
                                                               self.integrity_sync_status['date_start_master'])
                                                              .total_seconds()))

        # NOTE(review): presumably signals that a new integrity sync may start - confirm against the readers
        # of this flag.
        self.sync_integrity_free = True
        return result
コード例 #12
0
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator which
    produces values mapped via calling the callback function.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.

    Args:
        iterable: Source of values; its ``__aiter__()`` result is iterated.
        callback: Called with every value produced by the source. The result
            is awaited when it is awaitable.
        reject_callback: Optional handler called with an exception raised by
            the source (other than StopAsyncIteration / GeneratorExit). Its
            result is yielded in place of a value.
    """
    def __init__(
        self,
        iterable: AsyncIterable,
        callback: Callable,
        reject_callback: Optional[Callable] = None,
    ) -> None:
        self.iterator = iterable.__aiter__()
        self.callback = callback
        self.reject_callback = reject_callback
        # Set once the iterator is closed; also used to interrupt a pending
        # __anext__() call.
        self._close_event = Event()

    def __aiter__(self) -> "MapAsyncIterator":
        """Get the iterator object."""
        return self

    async def __anext__(self) -> Any:
        """Get the next value of the iterator.

        Raises:
            StopAsyncIteration: When the source is exhausted or this iterator
                gets closed.
        """
        if self.is_closed:
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            value = await self.iterator.__anext__()
            result = self.callback(value)

        else:
            # Race the next source value against the close event so that a
            # close wakes up a consumer blocked here.
            aclose = ensure_future(self._close_event.wait())
            anext = ensure_future(self.iterator.__anext__())

            try:
                pending: Set[Task] = (await
                                      wait([aclose, anext],
                                           return_when=FIRST_COMPLETED))[1]
            except CancelledError:
                # cancel underlying tasks and close
                aclose.cancel()
                anext.cancel()
                await self.aclose()
                raise  # re-raise the cancellation

            for task in pending:
                task.cancel()

            if aclose.done():
                raise StopAsyncIteration

            error = anext.exception()
            if error:
                if not self.reject_callback or isinstance(
                        error, (StopAsyncIteration, GeneratorExit)):
                    raise error
                result = self.reject_callback(error)
            else:
                value = anext.result()
                result = self.callback(value)

        return await result if isawaitable(result) else result

    async def athrow(
        self,
        type_: Union[BaseException, Type[BaseException]],
        value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Throw an exception into the asynchronous iterator.

        The call is forwarded to the underlying iterator when it provides
        ``athrow``. Otherwise this iterator is closed and the exception is
        raised here. Nothing happens if the iterator is already closed.

        Args:
            type_: Exception class or exception instance to throw.
            value: Exception instance to throw. When omitted it is derived
                from ``type_``.
            traceback: Optional traceback to attach to the exception.
        """
        if not self.is_closed:
            athrow = getattr(self.iterator, "athrow", None)
            if athrow:
                await athrow(type_, value, traceback)
            else:
                await self.aclose()
                if value is None:
                    if traceback is None:
                        raise type_
                    # `type_` may already be an exception instance; only
                    # instantiate it when it is a class.
                    value = (type_ if isinstance(type_, BaseException) else
                             cast(Type[BaseException], type_)())
                if traceback is not None:
                    value = value.with_traceback(traceback)
                raise value

    async def aclose(self) -> None:
        """Close the iterator (and the underlying iterator if it supports it)."""
        if not self.is_closed:
            aclose = getattr(self.iterator, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except RuntimeError:
                    # Best effort: a RuntimeError from the underlying aclose()
                    # must not prevent marking this iterator as closed.
                    pass
            self.is_closed = True

    @property
    def is_closed(self) -> bool:
        """Check whether the iterator is closed."""
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        """Mark the iterator as closed (truthy value) or open (falsy value)."""
        if value:
            self._close_event.set()
        else:
            self._close_event.clear()
コード例 #13
0
ファイル: justredis.py プロジェクト: tiwariashish86/cobra
class PubSubInstance:
    """Per-consumer handle on a shared pub/sub connection.

    Messages are pushed in through ``_add_message`` and consumed with
    ``message()``. Once ``aclose()`` has run, the public methods raise
    ``RedisError``.

    Args:
        pubsub: Backend object providing ``register``, ``unregister``,
            ``check_connection`` and ``ping``.
        encoder: Callable applied to every channel / pattern name. Falls back
            to ``utf8_encode`` when falsy.
        decoder: Optional callable applied to every message before it is
            returned.
    """
    __slots__ = '_pubsub', '_encoder', '_decoder', '_closed', '_messages', '_event'

    def __init__(self, pubsub, encoder, decoder):
        self._pubsub = pubsub
        self._encoder = encoder or utf8_encode
        self._decoder = decoder
        self._closed = False
        self._messages = deque()
        # Set by _add_message() to wake up a consumer blocked in message().
        self._event = Event()

    def _ensure_open(self):
        """Raise RedisError if this instance has already been closed."""
        if self._closed:
            raise RedisError('Pub/sub instance closed')

    async def __aenter__(self):
        self._ensure_open()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    async def aclose(self):
        """Unregister from the backend and drop all references. Idempotent."""
        if not self._closed:
            self._closed = True
            try:
                await self._pubsub.unregister(self)
            except Exception:
                # Best effort: closing must succeed even if the backend
                # cannot unregister this instance.
                pass
            self._messages = None
            self._decoder = None
            self._encoder = None
            self._pubsub = None
            self._event = None

    async def add(self, channels=None, patterns=None):
        """Subscribe to the given channels and/or patterns.

        Raises:
            RedisError: If the instance is closed.
        """
        # Check before touching self._pubsub, which is None once closed.
        self._ensure_open()
        await self._cmd(self._pubsub.register, channels, patterns)

    # TODO (question) should we removed the self._messages that are not related to this channels and patterns (left overs)?
    async def remove(self, channels=None, patterns=None):
        """Unsubscribe from the given channels and/or patterns.

        Raises:
            RedisError: If the instance is closed.
        """
        # Check before touching self._pubsub, which is None once closed.
        self._ensure_open()
        await self._cmd(self._pubsub.unregister, channels, patterns)

    async def message(self, timeout=None):
        """Return the next message, waiting for one if none is queued.

        Args:
            timeout: Maximum time to wait, passed to ``wait_for``. ``None``
                waits indefinitely.

        Returns:
            The next (decoded) message, or ``None`` if the timeout expired
            without one arriving.

        Raises:
            RedisError: If the instance is closed.
        """
        self._ensure_open()
        msg = self._get_message()
        if msg is not None:
            return msg
        self._event.clear()
        # We check connection here to notify the end user if there is an connection error...
        await self._pubsub.check_connection(self)
        if timeout is None:
            await self._event.wait()
        else:
            try:
                await wait_for(self._event.wait(), timeout)
            except AsyncIOTimeoutError:
                # A timeout is not an error: fall through and return None
                # unless a message was queued in the meantime.
                pass
        return self._get_message()

    async def ping(self, message=None):
        """Forward a ping to the backend.

        Raises:
            RedisError: If the instance is closed.
        """
        self._ensure_open()
        await self._pubsub.ping(message)

    async def _cmd(self, cmd, channels, patterns):
        """Encode the channel / pattern names and call ``cmd(self, channels, patterns)``.

        A single ``str`` / ``bytes`` value is treated as a one-element list.
        """
        self._ensure_open()
        if channels:
            if isinstance(channels, (str, bytes)):
                channels = [channels]
            channels = [self._encoder(x) for x in channels]
        if patterns:
            if isinstance(patterns, (str, bytes)):
                patterns = [patterns]
            patterns = [self._encoder(x) for x in patterns]
        await cmd(self, channels, patterns)

    def _add_message(self, msg):
        """Queue a message and wake a waiting consumer; ignored once closed."""
        if self._messages is not None:
            self._messages.append(msg)
            self._event.set()

    def _get_message(self):
        """Pop the oldest queued message, or return None if the queue is empty.

        The message is passed through the decoder when one is configured. If
        the resulting message is an exception instance it is raised instead
        of returned.
        """
        try:
            msg = self._messages.popleft()
        except IndexError:
            return None
        if self._decoder:
            msg = self._decoder(msg)
        if isinstance(msg, Exception):
            raise msg
        return msg