예제 #1
0
 def __init__(self):
     """Initialise connection state, lifecycle events and inbound queues."""
     self.connected = False
     # Lifecycle signals, created in the same order as before.
     self.connected_event, self.disconnected_event = Event(), Event()
     # Inbound queues for presence events, messages and errors.
     self.presence_queue, self.message_queue, self.error_queue = (
         Queue(), Queue(), Queue())
예제 #2
0
    def wait(self, promise, timeout=None):
        """Wait for *promise* to settle (resolve or reject).

        Registers one callback for both outcomes that sets a local ``Event``,
        then calls ``e.wait()``.

        NOTE(review): ``timeout`` is accepted but never used. If ``Event`` is
        ``asyncio.Event`` (as the comment below suggests), ``e.wait()`` only
        creates a coroutine that is never awaited, so this call returns
        immediately instead of blocking. TODO confirm and fix.
        """
        e = Event()

        def on_resolve_or_reject(_):
            # Same callback for both outcomes; the value/reason is ignored.
            e.set()

        promise._then(on_resolve_or_reject, on_resolve_or_reject)

        # We can't use the timeout in Asyncio event
        e.wait()
예제 #3
0
async def update_lifetime(
    loop: asyncio.BaseEventLoop, ev: asyncio.Event, lifetime: int
):
    """Run ``_do`` in the default executor roughly once a second until *ev* is set.

    Each iteration logs a counter and ``lifetime + 20``, then waits for both
    the executor job and a one-second sleep. Iterations therefore take at
    least one second. *ev* is only checked at the top of each iteration.
    Exceptions raised by ``_do`` stay on its future: ``asyncio.wait`` does
    not re-raise them.
    """
    i = 0
    while not ev.is_set():
        logger.info("update lifetime:%d %d", i, lifetime + 20)  # xxx: +20?
        # asyncio.wait() rejects bare coroutines since Python 3.11 (deprecated
        # since 3.8), so the sleep is scheduled as a task on the same loop.
        await asyncio.wait(
            [loop.run_in_executor(None, _do), loop.create_task(asyncio.sleep(1))]
        )
        i += 1
예제 #4
0
파일: channel.py 프로젝트: tbug/aiochannel
    def __init__(self, maxsize=0, *, loop=None):
        """Create a channel.

        Args:
            maxsize: non-negative int stored as ``_maxsize`` (0 is the default).
            loop: event loop to bind to; ``get_event_loop()`` is used when None.

        Raises:
            TypeError: if *maxsize* is not an int or is negative.
        """
        if loop is None:
            self._loop = get_event_loop()
        else:
            self._loop = loop

        if not isinstance(maxsize, int) or maxsize < 0:
            raise TypeError("maxsize must be an integer >= 0 (default is 0)")
        self._maxsize = maxsize

        # Futures.
        self._getters = deque()
        self._putters = deque()

        def make_event():
            # Python 3.10 removed the ``loop`` argument from asyncio primitives
            # (passing it raises TypeError; they bind to the running loop on
            # first use). Keep it where it still exists so older Pythons bind
            # to self._loop exactly as before.
            try:
                return Event(loop=self._loop)
            except TypeError:
                return Event()

        # "finished" means channel is closed and drained
        self._finished = make_event()
        self._close = make_event()

        self._init()
예제 #5
0
async def update_lifetime(
    loop: asyncio.BaseEventLoop, ev: asyncio.Event, lifetime: int
):
    """Spawn ``_do`` about once per second until *ev* is set.

    Each tick logs a counter and ``lifetime + 20``, hands ``_do`` to the
    ``spawn`` callable from ``spawn_task_scope(loop)``, then sleeps for one
    second. *ev* is checked before every tick.
    """
    with spawn_task_scope(loop) as spawn:
        tick = 0
        while True:
            if ev.is_set():
                break
            logger.info("update lifetime:%d %d", tick, lifetime + 20)  # xxx: +20?
            spawn(_do)
            await asyncio.sleep(1)
            tick += 1
예제 #6
0
    def _perform_heartbeat_loop(self):
        """Send one presence heartbeat for the currently subscribed channels/groups.

        Generator-based coroutine (driven with ``yield from``). Returns without
        sending anything when there are no presence channels or groups. The
        resulting status is announced to listeners according to
        ``config.heartbeat_notification_options``:

        - errors are announced for the ALL and FAILURES options;
        - successes are announced only for ALL.

        ``PubNubAsyncioException`` is swallowed (see TODO). The cancellation
        event handed to the request is always set when the call finishes.
        """
        if self._heartbeat_call is not None:
            # TODO: cancel call
            pass

        cancellation_event = Event()
        state_payload = self._subscription_state.state_payload()
        presence_channels = self._subscription_state.prepare_channel_list(False)
        presence_groups = self._subscription_state.prepare_channel_group_list(False)

        if len(presence_channels) == 0 and len(presence_groups) == 0:
            return

        try:
            heartbeat_call = (Heartbeat(self._pubnub)
                              .channels(presence_channels)
                              .channel_groups(presence_groups)
                              .state(state_payload)
                              .cancellation_event(cancellation_event)
                              .future())

            envelope = yield from heartbeat_call

            heartbeat_verbosity = self._pubnub.config.heartbeat_notification_options
            # ``is_error`` is a method (statuses are checked with
            # ``status.is_error()`` elsewhere); testing the bound method itself
            # was always truthy, so every heartbeat looked like a failure.
            if envelope.status.is_error():
                if heartbeat_verbosity in (PNHeartbeatNotificationOptions.ALL,
                                           PNHeartbeatNotificationOptions.FAILURES):
                    self._listener_manager.announce_status(envelope.status)
            elif heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL:
                self._listener_manager.announce_status(envelope.status)

        except PubNubAsyncioException as e:
            # Deliberately best-effort for now: a failed heartbeat is dropped.
            pass
            # TODO: check correctness
            # if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
            #     self._start_subscribe_loop()
            # else:
            #     self._listener_manager.announce_status(e.status)
        finally:
            cancellation_event.set()
예제 #7
0
파일: actor.py 프로젝트: iakinsey/illume
    def __init__(self, inbox, outbox, loop=None):
        """Store the inbox/outbox and create the loop-bound pause lock and stop event.

        Args:
            inbox: stored as-is on ``self.inbox``.
            outbox: stored as-is on ``self.outbox``.
            loop: event loop to bind to; ``get_event_loop()`` is used when falsy.

        ``self.on_init()`` is called last, after every attribute is set.
        """
        self.inbox = inbox
        self.outbox = outbox

        if not loop:
            loop = get_event_loop()

        self._loop = loop

        def bind_to_loop(primitive):
            # Python 3.10 removed the ``loop`` argument from asyncio primitives
            # (passing it raises TypeError; they bind to the running loop on
            # first use). Keep it where it still exists.
            try:
                return primitive(loop=self._loop)
            except TypeError:
                return primitive()

        self._pause_lock = bind_to_loop(Lock)
        self._stop_event = bind_to_loop(Event)
        self._test = None
        self.__testy = None

        self.on_init()
예제 #8
0
파일: compound.py 프로젝트: iakinsey/illume
class CompoundQueue(GeneratorQueue):
    """Queue facade that forwards ``start``/``put``/``stop`` to several queues.

    Every forwarded call is run concurrently on all ``queues`` via
    ``do_action``. ``get`` is not supported.
    """

    stop_event = None
    ready = None
    loop = None
    queues = None

    def __init__(self, queues, loop):
        self.ready = self._new_event(loop)
        self.stop_event = self._new_event(loop)
        self.queues = queues
        self.loop = loop

    @staticmethod
    def _new_event(loop):
        """Return an ``Event`` bound to *loop* where the running Python allows it."""
        # Python 3.10 removed the ``loop`` argument from asyncio primitives
        # (passing it raises TypeError); there the event binds to the running
        # loop on first use instead.
        try:
            return Event(loop=loop)
        except TypeError:
            return Event()

    async def start(self):
        """Start every underlying queue, then mark this queue ready.

        Raises:
            QueueError: if the queue has already been stopped.
        """
        if self.stop_event.is_set():
            raise QueueError("Socket already stopped.")

        await self.do_action("start")
        self.ready.set()

    @dies_on_stop_event
    async def get(self):
        """Not supported for a compound queue."""
        raise NotImplementedError()

    @dies_on_stop_event
    async def put(self, data):
        """Start the queue if needed, then put *data* on every underlying queue."""
        await self.setup()
        await self.ready.wait()
        await self.do_action("put", (data,))

    async def setup(self):
        """Setup the client."""
        if not self.ready.is_set():
            await self.start()

    async def stop(self):
        """Stop queue."""
        self.ready.clear()
        self.stop_event.set()

        await self.do_action("stop")

    async def do_action(self, name, args=()):
        """Call coroutine method *name* with *args* on every queue concurrently.

        Waits until all calls have finished. As with ``asyncio.wait``,
        exceptions raised by individual queues are not re-raised here.
        """
        import asyncio

        if not self.queues:
            # asyncio.wait() raises ValueError for an empty collection.
            return

        methods = [getattr(queue, name) for queue in self.queues]
        # asyncio.wait() rejects bare coroutines (Python 3.11) and its ``loop``
        # argument is gone (Python 3.10), so schedule each call as a task on
        # self.loop first.
        tasks = [asyncio.ensure_future(method(*args), loop=self.loop)
                 for method in methods]

        await asyncio.wait(tasks)
예제 #9
0
    def __init__(self, size=1, queue_timeout=15.0, *, loop=None, **config):
        """Configure the pool; no connections are opened here.

        Args:
            size: pool size; must be >= 1.
            queue_timeout: stored as ``_queue_timeout``.
            loop: event loop to use; ``asyncio.get_event_loop()`` when falsy.
            **config: stored as-is on ``self.config``.

        Raises:
            ValueError: if *size* is less than 1.
        """
        # Validate with an exception rather than ``assert``: an assert used to
        # pre-empt this ValueError, and asserts vanish under ``python -O``.
        if size < 1:
            raise ValueError('DBPool.size is less than 1, '
                             'connections won"t be established')
        self._pool = set()
        self._busy_items = set()
        self._size = size
        self._pending_futures = deque()
        self._queue_timeout = queue_timeout
        self._loop = loop or asyncio.get_event_loop()
        self.config = config

        # Python 3.10 removed the ``loop`` argument from asyncio primitives
        # (passing it raises TypeError); keep it where it still exists.
        try:
            self._shutdown_event = Event(loop=self._loop)
        except TypeError:
            self._shutdown_event = Event()
        # The event starts out set; its meaning is defined by code not shown here.
        self._shutdown_event.set()
예제 #10
0
파일: coro.py 프로젝트: vivisect/synapse
async def event_wait(event: asyncio.Event, timeout=None):
    '''
    Wait on an asyncio event with an optional timeout.

    Args:
        event: the event to wait for.
        timeout: maximum number of seconds to wait, or None to wait forever.

    Returns:
        True if the event got set, False if timed out
    '''
    if timeout is None:
        await event.wait()
        return True

    try:
        await asyncio.wait_for(event.wait(), timeout)
    except asyncio.TimeoutError:
        return False
    return True
예제 #11
0
 def __init__(self):
     """Run the base-class initialiser, then create ``cancel_event``."""
     super().__init__()
     # Created after the base initialiser so it cannot be overwritten by it.
     cancel_event = Event()
     self.cancel_event = cancel_event
예제 #12
0
class SubscribeListener(SubscribeCallback):
    """Subscribe callback that exposes PubNub notifications as awaitables.

    Subscribe/unsubscribe statuses set ``connected_event``/``disconnected_event``.
    Messages and presence events are queued. Error statuses queue their
    exception on ``error_queue``, which makes the next pending ``wait_for_*``
    call raise it.
    """

    def __init__(self):
        self.connected = False
        self.connected_event = Event()
        self.disconnected_event = Event()
        self.presence_queue = Queue()
        self.message_queue = Queue()
        self.error_queue = Queue()

    def status(self, pubnub, status):
        if utils.is_subscribed_event(status) and not self.connected_event.is_set():
            self.connected_event.set()
        elif utils.is_unsubscribed_event(status) and not self.disconnected_event.is_set():
            self.disconnected_event.set()
        elif status.is_error():
            self.error_queue.put_nowait(status.error_data.exception)

    def message(self, pubnub, message):
        self.message_queue.put_nowait(message)

    def presence(self, pubnub, presence):
        self.presence_queue.put_nowait(presence)

    async def _wait_for(self, coro):
        """Await *coro*, unless an error from ``error_queue`` arrives first.

        Returns:
            The result of *coro*.

        Raises:
            The exception taken from ``error_queue`` if it arrives before
            *coro* completes.
        """
        scc_task = asyncio.ensure_future(coro)
        err_task = asyncio.ensure_future(self.error_queue.get())

        try:
            await asyncio.wait([scc_task, err_task],
                               return_when=asyncio.FIRST_COMPLETED)
        except BaseException:
            # Cancelled (or otherwise interrupted) while waiting: don't leave an
            # orphaned error_queue.get() behind, it would swallow the next error.
            scc_task.cancel()
            err_task.cancel()
            raise

        if err_task.done() and not scc_task.done():
            scc_task.cancel()
            raise err_task.result()

        if err_task.done():
            # Both finished in the same loop iteration: return the result, but
            # put the error back so the next waiter still sees it.
            self.error_queue.put_nowait(err_task.result())
        else:
            err_task.cancel()
        return scc_task.result()

    async def wait_for_connect(self):
        """Wait until a subscribe status arrives.

        Raises:
            Exception: if already connected, or an error status arrives first.
        """
        if not self.connected_event.is_set():
            await self._wait_for(self.connected_event.wait())
        else:
            raise Exception("instance is already connected")

    async def wait_for_disconnect(self):
        """Wait until an unsubscribe status arrives.

        Raises:
            Exception: if already disconnected, or an error status arrives first.
        """
        if not self.disconnected_event.is_set():
            await self._wait_for(self.disconnected_event.wait())
        else:
            raise Exception("instance is already disconnected")

    async def wait_for_message_on(self, *channel_names):
        """Return the next queued message whose ``channel`` is in *channel_names*.

        Messages for other channels are consumed and discarded.
        """
        channel_names = list(channel_names)
        while True:
            env = await self._wait_for(self.message_queue.get())
            # Acknowledge only an item that was actually dequeued; doing this
            # in ``finally`` also ran when _wait_for raised and could raise
            # "task_done() called too many times", masking the real error.
            self.message_queue.task_done()
            if env.channel in channel_names:
                return env

    async def wait_for_presence_on(self, *channel_names):
        """Return the next queued presence event whose ``channel`` is in *channel_names*.

        Presence events for other channels are consumed and discarded.
        """
        channel_names = list(channel_names)
        while True:
            env = await self._wait_for(self.presence_queue.get())
            # Acknowledge only an item that was actually dequeued (see above).
            self.presence_queue.task_done()
            if env.channel in channel_names:
                return env
예제 #13
0
    async def _sync(self) -> OutboundMessageGenerator:
        """
        Performs a full sync of the blockchain.
            - Check which are the heaviest tips
            - Request headers for the heaviest
            - Verify the weight of the tip, using the headers
            - Find the fork point to see where to start downloading blocks
            - Blacklist peers that provide invalid blocks
            - Sync blockchain up to heads (request blocks in batches)

        This is an async generator: every request to peers is yielded as an
        ``OutboundMessage`` (delivered to a random full node), and the replies
        are awaited through the ``Event`` objects registered in ``self.store``.
        It returns early when ``self._shut_down`` becomes true, or when no
        potential tip is heavier than our current tips.

        NOTE(review): the blacklisting step listed above does not happen in
        this method.

        Raises:
            ValueError: a potential tip block has no challenge.
            TimeoutError: peers did not deliver the header hashes, a batch of
                headers or a batch of blocks within ``timeout`` seconds.
            errors.InvalidWeight: the downloaded headers do not justify the
                tip's weight.
            RuntimeError: a downloaded block was invalid or disconnected.
        """
        log.info("Starting to perform sync with peers.")
        log.info("Waiting to receive tips from peers.")
        # TODO: better way to tell that we have finished receiving tips
        await asyncio.sleep(5)
        highest_weight: uint64 = uint64(0)
        tip_block: FullBlock
        tip_height = 0
        sync_start_time = time.time()

        # Based on responses from peers about the current heads, see which head is the heaviest
        # (similar to longest chain rule).

        potential_tips: List[
            Tuple[bytes32, FullBlock]
        ] = self.store.get_potential_tips_tuples()
        log.info(f"Have collected {len(potential_tips)} potential tips")
        for header_hash, potential_tip_block in potential_tips:
            if potential_tip_block.header_block.challenge is None:
                raise ValueError(
                    f"Invalid tip block {potential_tip_block.header_hash} received"
                )
            if potential_tip_block.header_block.challenge.total_weight > highest_weight:
                highest_weight = potential_tip_block.header_block.challenge.total_weight
                tip_block = potential_tip_block
                tip_height = potential_tip_block.header_block.challenge.height
        if highest_weight <= max(
            [t.weight for t in self.blockchain.get_current_tips()]
        ):
            log.info("Not performing sync, already caught up.")
            return

        assert tip_block
        log.info(f"Tip block {tip_block.header_hash} tip height {tip_block.height}")

        # One "received" event per height for headers and for blocks; they are
        # set as the corresponding responses arrive.
        for height in range(0, tip_block.height + 1):
            self.store.set_potential_headers_received(uint32(height), Event())
            self.store.set_potential_blocks_received(uint32(height), Event())
        # A single event for the "all header hashes" response.
        self.store.set_potential_hashes_received(Event())

        timeout = 200
        sleep_interval = 10
        total_time_slept = 0

        while True:
            if total_time_slept > timeout:
                raise TimeoutError("Took too long to fetch header hashes.")
            if self._shut_down:
                return
            # Download all the header hashes and find the fork point
            request = peer_protocol.RequestAllHeaderHashes(tip_block.header_hash)
            yield OutboundMessage(
                NodeType.FULL_NODE,
                Message("request_all_header_hashes", request),
                Delivery.RANDOM,
            )
            try:
                phr = self.store.get_potential_hashes_received()
                assert phr is not None
                await asyncio.wait_for(
                    phr.wait(), timeout=sleep_interval,
                )
                break
            except asyncio.TimeoutError:
                # asyncio.wait_for raises asyncio.TimeoutError, which is not
                # concurrent.futures.TimeoutError on Python 3.8-3.10.
                total_time_slept += sleep_interval
                log.warning("Did not receive desired header hashes")

        # Finding the fork point allows us to only download headers and blocks from the fork point
        header_hashes = self.store.get_potential_hashes()
        fork_point_height: uint32 = self.blockchain.find_fork_point(header_hashes)
        fork_point_hash: bytes32 = header_hashes[fork_point_height]
        log.info(f"Fork point: {fork_point_hash} at height {fork_point_height}")

        # Now, we download all of the headers in order to verify the weight, in batches
        headers: List[HeaderBlock] = []

        # Download headers in batches. We download a few batches ahead in case there are delays or peers
        # that don't have the headers that we need.
        last_request_time: float = 0
        highest_height_requested: uint32 = uint32(0)
        request_made: bool = False
        for height_checkpoint in range(
            fork_point_height + 1, tip_height + 1, self.config["max_headers_to_send"]
        ):
            end_height = min(
                height_checkpoint + self.config["max_headers_to_send"], tip_height + 1
            )

            total_time_slept = 0
            while True:
                if self._shut_down:
                    return
                if total_time_slept > timeout:
                    raise TimeoutError("Took too long to fetch blocks")

                # Request batches that we don't have yet
                for batch in range(0, self.config["num_sync_batches"]):
                    batch_start = (
                        height_checkpoint + batch * self.config["max_headers_to_send"]
                    )
                    batch_end = min(
                        batch_start + self.config["max_headers_to_send"], tip_height + 1
                    )

                    if batch_start > tip_height:
                        # We have asked for all blocks
                        break

                    blocks_missing = any(
                        [
                            not (
                                self.store.get_potential_headers_received(uint32(h))
                            ).is_set()
                            for h in range(batch_start, batch_end)
                        ]
                    )
                    if (
                        time.time() - last_request_time > sleep_interval
                        and blocks_missing
                    ) or (batch_end - 1) > highest_height_requested:
                        # If we are missing header blocks in this batch, and we haven't made a request in a while,
                        # Make a request for this batch. Also, if we have never requested this batch, make
                        # the request
                        if batch_end - 1 > highest_height_requested:
                            highest_height_requested = batch_end - 1

                        request_made = True
                        request_hb = peer_protocol.RequestHeaderBlocks(
                            tip_block.header_block.header.get_hash(),
                            [uint32(h) for h in range(batch_start, batch_end)],
                        )
                        log.info(f"Requesting header blocks {batch_start, batch_end}.")
                        yield OutboundMessage(
                            NodeType.FULL_NODE,
                            Message("request_header_blocks", request_hb),
                            Delivery.RANDOM,
                        )
                if request_made:
                    # Reset the timer for requests, so we don't overload other peers with requests
                    last_request_time = time.time()
                    request_made = False

                # Wait for the first batch (the next "max_headers_to_send" headers to arrive)
                awaitables = [
                    (self.store.get_potential_headers_received(uint32(height))).wait()
                    for height in range(height_checkpoint, end_height)
                ]
                future = asyncio.gather(*awaitables, return_exceptions=True)
                try:
                    await asyncio.wait_for(future, timeout=sleep_interval)
                    break
                except asyncio.TimeoutError:
                    # wait_for cancelled the gather; await it so the
                    # cancellation is consumed before retrying.
                    try:
                        await future
                    except asyncio.CancelledError:
                        pass
                    total_time_slept += sleep_interval
                    log.info("Did not receive desired header blocks")

        for h in range(fork_point_height + 1, tip_height + 1):
            header = self.store.get_potential_header(uint32(h))
            assert header is not None
            headers.append(header)

        log.info(f"Downloaded headers up to tip height: {tip_height}")
        if not verify_weight(
            tip_block.header_block, headers, self.blockchain.headers[fork_point_hash],
        ):
            raise errors.InvalidWeight(
                f"Weight of {tip_block.header_block.header.get_hash()} not valid."
            )

        log.info(
            f"Validated weight of headers. Downloaded {len(headers)} headers, tip height {tip_height}"
        )
        assert tip_height == fork_point_height + len(headers)
        self.store.clear_potential_headers()
        headers.clear()

        # Download blocks in batches, and verify them as they come in. We download a few batches ahead,
        # in case there are delays.
        last_request_time = 0
        highest_height_requested = uint32(0)
        request_made = False
        for height_checkpoint in range(
            fork_point_height + 1, tip_height + 1, self.config["max_blocks_to_send"]
        ):
            end_height = min(
                height_checkpoint + self.config["max_blocks_to_send"], tip_height + 1
            )

            total_time_slept = 0
            while True:
                if self._shut_down:
                    return
                if total_time_slept > timeout:
                    raise TimeoutError("Took too long to fetch blocks")

                # Request batches that we don't have yet
                for batch in range(0, self.config["num_sync_batches"]):
                    batch_start = (
                        height_checkpoint + batch * self.config["max_blocks_to_send"]
                    )
                    batch_end = min(
                        batch_start + self.config["max_blocks_to_send"], tip_height + 1
                    )

                    if batch_start > tip_height:
                        # We have asked for all blocks
                        break

                    blocks_missing = any(
                        [
                            not (
                                self.store.get_potential_blocks_received(uint32(h))
                            ).is_set()
                            for h in range(batch_start, batch_end)
                        ]
                    )
                    if (
                        time.time() - last_request_time > sleep_interval
                        and blocks_missing
                    ) or (batch_end - 1) > highest_height_requested:
                        # If we are missing blocks in this batch, and we haven't made a request in a while,
                        # Make a request for this batch. Also, if we have never requested this batch, make
                        # the request
                        log.info(
                            f"Requesting sync blocks {[i for i in range(batch_start, batch_end)]}"
                        )
                        if batch_end - 1 > highest_height_requested:
                            highest_height_requested = batch_end - 1
                        request_made = True
                        request_sync = peer_protocol.RequestSyncBlocks(
                            tip_block.header_block.header.header_hash,
                            [
                                uint32(height)
                                for height in range(batch_start, batch_end)
                            ],
                        )
                        yield OutboundMessage(
                            NodeType.FULL_NODE,
                            Message("request_sync_blocks", request_sync),
                            Delivery.RANDOM,
                        )
                if request_made:
                    # Reset the timer for requests, so we don't overload other peers with requests
                    last_request_time = time.time()
                    request_made = False

                # Wait for the first batch (the next "max_blocks_to_send" blocks to arrive)
                awaitables = [
                    (self.store.get_potential_blocks_received(uint32(height))).wait()
                    for height in range(height_checkpoint, end_height)
                ]
                future = asyncio.gather(*awaitables, return_exceptions=True)
                try:
                    await asyncio.wait_for(future, timeout=sleep_interval)
                    break
                except asyncio.TimeoutError:
                    # wait_for cancelled the gather; await it so the
                    # cancellation is consumed before retrying.
                    try:
                        await future
                    except asyncio.CancelledError:
                        pass
                    total_time_slept += sleep_interval
                    log.info("Did not receive desired blocks")

            # Verifies this batch, which we are guaranteed to have (since we broke from the above loop)
            blocks = []
            for height in range(height_checkpoint, end_height):
                b: Optional[FullBlock] = await self.store.get_potential_block(
                    uint32(height)
                )
                assert b is not None
                blocks.append(b)

            validation_start_time = time.time()
            prevalidate_results = await self.blockchain.pre_validate_blocks(blocks)
            index = 0
            for height in range(height_checkpoint, end_height):
                if self._shut_down:
                    return
                block: Optional[FullBlock] = await self.store.get_potential_block(
                    uint32(height)
                )
                assert block is not None

                prev_block: Optional[FullBlock] = await self.store.get_potential_block(
                    uint32(height - 1)
                )
                if prev_block is None:
                    prev_block = await self.store.get_block(block.prev_header_hash)
                assert prev_block is not None

                # The block gets permanantly added to the blockchain
                validated, pos = prevalidate_results[index]
                index += 1

                async with self.store.lock:
                    result = await self.blockchain.receive_block(
                        block, prev_block.header_block, validated, pos
                    )
                    if (
                        result == ReceiveBlockResult.INVALID_BLOCK
                        or result == ReceiveBlockResult.DISCONNECTED_BLOCK
                    ):
                        raise RuntimeError(f"Invalid block {block.header_hash}")

                    # Always immediately add the block to the database, after updating blockchain state
                    await self.store.add_block(block)

                assert (
                    max([h.height for h in self.blockchain.get_current_tips()])
                    >= height
                )
                self.store.set_proof_of_time_estimate_ips(
                    self.blockchain.get_next_ips(block.header_block)
                )
            log.info(
                f"Took {time.time() - validation_start_time} seconds to validate and add blocks "
                f"{height_checkpoint} to {end_height}."
            )
        assert max([h.height for h in self.blockchain.get_current_tips()]) == tip_height
        log.info(
            f"Finished sync up to height {tip_height}. Total time: "
            f"{round((time.time() - sync_start_time)/60, 2)} minutes."
        )
예제 #14
0
class StorageAdaptor(LockStorage):
    """Proxy to ``store`` that tracks in-flight work to allow a graceful shutdown.

    ``acquire``, ``release``, ``release_all``, ``unrelease_all``,
    ``set_client_last_address``, ``add_signal`` and ``remove_signal`` run inside
    ``StorageOperationGuard``. That guard presumably brackets them with
    ``on_storage_operation_start``/``on_storage_operation_end``; confirm.
    Command handlers call ``on_command_start``/``on_command_end``.

    On SIGINT/SIGTERM (or ``terminate()``), the store is dumped as soon as no
    storage operation is running. The process exits once no command is running.
    """

    def __init__(self, store):
        self.store = store
        self.kill_active = False
        self._storage_operations_in_process = 0
        self._commands_in_process = 0
        self._data_dumped = False
        # Commands wait on this event before starting; maintenance() clears it
        # while running and sets it afterwards.
        # NOTE(review): the event starts unset, so on_command_start() blocks
        # until the first maintenance() call (or until something else sets it).
        # Confirm this is intended.
        self._maintenance_event = Event()

        signal.signal(signal.SIGINT, self._on_kill_requested)
        signal.signal(signal.SIGTERM, self._on_kill_requested)

        super().__init__()

    def terminate(self):
        """Request shutdown as if a termination signal had been received."""
        self._on_kill_requested()

    def on_storage_operation_start(self):
        self._storage_operations_in_process += 1

    def on_storage_operation_end(self):
        """Count a finished storage operation; dump data if shutdown is pending.

        Raises:
            Exception: if the counter goes negative (unbalanced start/end calls).
        """
        self._storage_operations_in_process -= 1
        if self._storage_operations_in_process < 0:
            raise Exception('Storage active operations counter broken')
        if self.kill_active and not self._storage_operations_in_process:
            # Last storage operation ended, good point to dump storage data
            # then exit() after sending replies to clients
            self.dump_before_die()

    def is_commands_in_process(self):
        return self._commands_in_process > 0

    async def on_command_start(self):
        """Wait for the maintenance event to be set, then count the command."""
        await self._maintenance_event.wait()
        self._commands_in_process += 1

    def on_command_end(self):
        """Count a finished command; exit if shutdown is pending and none remain."""
        self._commands_in_process -= 1
        if self.kill_active and not self._commands_in_process:
            if self._storage_operations_in_process:
                logger.error(
                    'Last command executed, but storage operations counter has value'
                )
            self.die()

    def _on_kill_requested(self, *args, **kwargs):
        """Signal handler (signum, frame) that starts a graceful shutdown."""
        # Must wait for active commands to complete, then dump data, close all connections and exit
        self.kill_active = True
        logger.info('Received terminate signal %s %s', args, kwargs)
        if not self._storage_operations_in_process:
            # Good time to dump and exit(), but there is chance that some client's will not get their responses
            self.dump_before_die()

        # In case any storage operations is in progress, we will dump data after last operations completes
        # and will exit() when last reply is sent
        if not self._commands_in_process:
            self.die()

    def dump_before_die(self):
        # Just dump data, because all storage operations ended
        self.store.dump()
        self._data_dumped = True
        logger.debug('Dumped data')

    def die(self):
        """Dump the store unless already dumped, then exit with status 1."""
        import sys

        # All commands processed, can exit safely
        if not self._data_dumped:
            self.store.dump()
        logger.debug('Exit')
        # sys.exit rather than the site-provided exit(): that helper is meant
        # for the interactive shell and is missing when Python runs with -S.
        sys.exit(1)

    def acquire(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.acquire(*args, **kwargs)

    def release(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.release(*args, **kwargs)

    def release_all(self, client_id):
        # ``release_all_timeout`` is not defined in this class; presumably it
        # comes from LockStorage. Confirm.
        with StorageOperationGuard(self):
            return self.store.release_all(client_id=client_id,
                                          timeout=self.release_all_timeout)

    def unrelease_all(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.unrelease_all(*args, **kwargs)

    def locked(self, *args, **kwargs):
        return self.store.locked(*args, **kwargs)

    def set_client_last_address(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.set_client_last_address(*args, **kwargs)

    def add_signal(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.add_signal(*args, **kwargs)

    def has_signal(self, *args, **kwargs):
        return self.store.has_signal(*args, **kwargs)

    def remove_signal(self, *args, **kwargs):
        with StorageOperationGuard(self):
            return self.store.remove_signal(*args, **kwargs)

    def get_client_last_address(self, *args, **kwargs):
        return self.store.get_client_last_address(*args, **kwargs)

    def find(self, *args, **kwargs):
        return self.store.find(*args, **kwargs)

    def dump(self):
        return self.store.dump()

    def load_dump(self):
        return self.store.load_dump()

    def clear_dump(self):
        return self.store.clear_dump()

    def stats(self):
        return self.store.stats()

    def maintenance(self, *args, **kwargs):
        """Run ``store.maintenance`` with the maintenance event cleared.

        The event is set again afterwards, even if maintenance raises.
        """
        logger.debug('Doing maintenance')
        try:
            self._maintenance_event.clear()
            return self.store.maintenance(*args, **kwargs)
        finally:
            self._maintenance_event.set()
예제 #15
0
 async def listen(driver_ready: asyncio.Event) -> None:
     """Signal *driver_ready*, then wait to be cancelled.

     Raises:
         AssertionError: if the coroutine is not cancelled within 30 seconds.
     """
     driver_ready.set()
     await asyncio.sleep(30)
     # Raise explicitly: ``assert False`` is stripped under ``python -O``,
     # which would let this coroutine return normally and hide the failure.
     raise AssertionError("Listen wasn't canceled!")
예제 #16
0
async def worker(
    id_,
    task_q: asyncio.Queue,
    send: asyncio.Queue,
    recv: asyncio.Queue,
    exclude: set,
    used: asyncio.Queue,
    unreachable: asyncio.Queue,
    event: asyncio.Event,
    delimiter: bytes,
    timeout=None,
):
    """Hand ports from *task_q* to announced client workers and serve each one.

    Until *task_q* is empty or *event* is set, each round:

    1. takes a worker announcement from *recv*;
    2. takes the next port from *task_q* (ports in *exclude* are skipped and
       the announcement is returned to *recv*);
    3. puts the port on *send* and starts a server on it, waiting up to
       *timeout* seconds for one connection to be handled.

    Ports that cannot be bound go to *used*. Ports that time out or fail an
    assertion go to *unreachable*. When done, ``"DONE"`` is put on *send*.
    Any other exception sets *event* (which the other workers' loops check)
    and is re-raised.
    """
    async def worker_handler(
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        port: int,
        handle_finished: asyncio.Event,
    ):
        print(SharedData.green(f"[SS{id_:2}][INFO] --- IN HANDLER ---"))
        try:
            # Use the ``port`` argument, not the enclosing loop variable ``p``:
            # ``p`` is rebound every round, so a handler still running after
            # its round timed out would otherwise send the wrong port.
            await tcp_send(port, writer, delimiter, timeout=timeout)
            print(SharedData.green(f"[SS{id_:2}][INFO] Port {port} is open."))
        finally:
            # Close the connection even when tcp_send fails.
            writer.close()
            await writer.wait_closed()
        handle_finished.set()

    try:
        while not task_q.empty() and not event.is_set():

            # receive worker announcement.
            try:
                worker_id = await asyncio.wait_for(recv.get(), timeout)
                recv.task_done()
            except asyncio.TimeoutError:
                print(SharedData.red(f"[SS{id_:2}][Warn] Timeout."))
                continue

            print(f"[SS{id_:2}][INFO] Worker {worker_id} available.")

            # get next work.
            print(f"[SS{id_:2}][INFO] Getting new port.")

            # if timeout getting port, either task is empty or just coroutine delayed.
            try:
                p: int = await asyncio.wait_for(task_q.get(), timeout)
                task_q.task_done()
            except asyncio.TimeoutError:
                if task_q.empty():
                    break
                # put the worker back in and run again.
                await recv.put(worker_id)
                continue

            # check if port is in blacklist.
            if p in exclude:
                print(SharedData.cyan(f"[SS{id_:2}][INFO] Skipping Port {p}."))
                # No port goes to this worker, so hand its announcement back
                # (as the timeout path above does) instead of dropping it.
                await recv.put(worker_id)
                continue

            print(f"[SS{id_:2}][INFO] Sending port {p} to client.")
            await send.put(p)

            handle_ev = asyncio.Event()

            print(f"[SS{id_:2}][INFO] Trying to serve port {p}.")
            try:
                # Bind this round's port and event as defaults so the callback
                # never sees values from a later round.
                child_sock = await asyncio.start_server(
                    lambda r, w, port=p, done=handle_ev: worker_handler(r, w, port, done),
                    port=p)

            except AssertionError:
                print(
                    SharedData.red(
                        f"[SS{id_:2}][INFO] Port {p} assertion failed!"))
                await unreachable.put(p)

            except OSError:
                print(SharedData.red(f"[SS{id_:2}][Warn] Port {p} in use."))
                await used.put(p)

            else:
                try:
                    await child_sock.start_serving()
                    await asyncio.wait_for(handle_ev.wait(), timeout)

                except asyncio.TimeoutError:
                    print(
                        SharedData.red(f"[SS{id_:2}][Warn] Port {p} timeout."))
                    await unreachable.put(p)
                finally:
                    child_sock.close()
                    await child_sock.wait_closed()

        # Send end signal to client.
        # first worker catching this signal will go offline.

        print(SharedData.cyan(f"[SS{id_:2}][INFO] Done. Sending stop signal."))
        # A str among int ports: causes an int type-error on client side workers.
        await send.put("DONE")

    except Exception:
        # trigger event to stop all threads.
        print(SharedData.red(f"[SS{id_:2}][CRIT] Exception Event set!."))
        event.set()
        raise

    if event.is_set():
        print(SharedData.bold(f"[SS{id_:2}][WARN] Task Finished by event."))
    else:
        print(SharedData.bold(f"[SS{id_:2}][INFO] Task Finished."))
예제 #17
0
파일: protocol.py 프로젝트: travigd/grpclib
class Stream:
    """
    API for working with streams, used by clients and request handlers

    Each instance wraps one HTTP/2 stream of an ``H2Connection``. Every write
    method first awaits ``connection.write_ready``, so writers block while the
    owning ``Connection`` has writing paused. ``id`` stays ``None`` until
    ``init_stream()`` assigns a stream id, either from the constructor or from
    ``send_request()``.
    """
    id: Optional[int] = None

    # stats
    created: Optional[float] = None  # time.monotonic() set by init_stream()
    data_sent = 0  # DATA payload bytes sent by send_data()
    data_received = 0

    def __init__(self,
                 connection: Connection,
                 h2_connection: H2Connection,
                 transport: Transport,
                 *,
                 stream_id: Optional[int] = None,
                 wrapper: Optional[Wrapper] = None) -> None:
        self.connection = connection
        self._h2_connection = h2_connection
        self._transport = transport
        # when set, __terminated__() cancels it with StreamTerminatedError
        self.wrapper = wrapper

        # A known id means the stream already exists (presumably opened by the
        # peer); otherwise send_request() allocates an id later.
        if stream_id is not None:
            self.init_stream(stream_id, self.connection)

        # These events are set outside this class (presumably by the h2 events
        # processor); this class only waits on / clears them.
        self.window_updated = Event()
        self.headers: Optional['_Headers'] = None
        self.headers_received = Event()
        self.trailers: Optional['_Headers'] = None
        self.trailers_received = Event()

    def init_stream(self, stream_id: int, connection: Connection) -> None:
        """Bind this object to HTTP/2 stream *stream_id* and update stats."""
        self.id = stream_id
        # The Buffer gets a callback bound to this stream id; presumably it is
        # invoked with the number of consumed bytes so Connection.ack() can
        # acknowledge them for flow control.
        self.buffer = Buffer(partial(connection.ack, self.id))

        self.connection.streams_started += 1
        self.created = self.connection.last_stream_created = time.monotonic()

    async def recv_headers(self) -> _Headers:
        """Return the received headers, waiting for them if necessary."""
        if self.headers is None:
            await self.headers_received.wait()
        assert self.headers is not None
        return self.headers

    async def recv_data(self, size: int) -> bytes:
        """Read up to *size* bytes of received data from the stream buffer."""
        return await self.buffer.read(size)

    async def recv_trailers(self) -> _Headers:
        """Return the received trailers, waiting for them if necessary."""
        if self.trailers is None:
            await self.trailers_received.wait()
        assert self.trailers is not None
        return self.trailers

    async def send_request(
        self,
        headers: _Headers,
        end_stream: bool = False,
        *,
        _processor: 'EventsProcessor',
    ) -> Callable[[], None]:
        """Open a new stream by sending *headers*.

        If h2 refuses with TooManyStreamsError, wait on
        ``connection.stream_close_waiter`` and retry from the top.

        Returns the callable produced by ``_processor.register(self)``.
        """
        assert self.id is None, self.id
        while True:
            # this is the first thing we should check before even trying to
            # create new stream, because this wait() can be cancelled by timeout
            # and we wouldn't need to create new stream at all
            await self.connection.write_ready.wait()

            # `get_next_available_stream_id()` should be as close to
            # `connection.send_headers()` as possible, without any async
            # interruptions in between, see the docs on the
            # `get_next_available_stream_id()` method
            stream_id = self._h2_connection.get_next_available_stream_id()
            try:
                self._h2_connection.send_headers(stream_id,
                                                 headers,
                                                 end_stream=end_stream)
            except TooManyStreamsError:
                # we're going to wait until any of currently opened streams will
                # be closed, and we will be able to open a new one
                # TODO: maybe implement FIFO for waiters, but this limit
                #       shouldn't be reached in a normal case, so why bother
                # TODO: maybe we should raise an exception here instead of
                #       waiting, if timeout wasn't set for the current request
                self.connection.stream_close_waiter.clear()
                await self.connection.stream_close_waiter.wait()
                # while we were trying to create a new stream, write buffer
                # can became full, so we need to repeat checks from checking
                # if we can write() data
                continue
            else:
                self.init_stream(stream_id, self.connection)
                release_stream = _processor.register(self)
                self._transport.write(self._h2_connection.data_to_send())
                self.connection.headers_send_process()
                return release_stream

    async def send_headers(
        self,
        headers: _Headers,
        end_stream: bool = False,
    ) -> None:
        """Send *headers* on this already-open stream.

        Raises StreamClosedError if h2 no longer tracks this stream id.
        """
        assert self.id is not None
        await self.connection.write_ready.wait()

        # Workaround for the H2Connection.send_headers method, which will try
        # to create a new stream if it was removed earlier from the
        # H2Connection.streams, and therefore will raise StreamIDTooLowError
        if self.id not in self._h2_connection.streams:
            raise StreamClosedError(self.id)

        self._h2_connection.send_headers(self.id,
                                         headers,
                                         end_stream=end_stream)
        self._transport.write(self._h2_connection.data_to_send())
        self.connection.headers_send_process()

    async def send_data(self, data: bytes, end_stream: bool = False) -> None:
        """Send *data* as one or more DATA frames.

        Each chunk is at most min(local flow-control window, max outbound
        frame size, remaining bytes). When *end_stream* is true only the final
        chunk carries END_STREAM; an empty *data* still sends one (empty)
        frame once the window is positive.
        """
        f = BytesIO(data)
        f_pos, f_last = 0, len(data)

        while True:
            await self.connection.write_ready.wait()

            window = self._h2_connection.local_flow_control_window(self.id)
            # window can become negative
            if not window > 0:
                self.window_updated.clear()
                await self.window_updated.wait()
                # during "await" above other streams were able to send data and
                # decrease current window size, so try from the beginning
                continue

            max_frame_size = self._h2_connection.max_outbound_frame_size
            f_chunk = f.read(min(window, max_frame_size, f_last - f_pos))
            f_chunk_len = len(f_chunk)
            f_pos = f.tell()

            if f_pos == f_last:
                # last chunk: this is the only frame that may end the stream
                self._h2_connection.send_data(self.id,
                                              f_chunk,
                                              end_stream=end_stream)
                self._transport.write(self._h2_connection.data_to_send())
                self.data_sent += f_chunk_len
                self.connection.data_sent += f_chunk_len
                self.connection.data_send_process()
                break
            else:
                self._h2_connection.send_data(self.id, f_chunk)
                self._transport.write(self._h2_connection.data_to_send())
                self.data_sent += f_chunk_len
                self.connection.data_sent += f_chunk_len
                self.connection.data_send_process()

    async def end(self) -> None:
        """Close our side of the stream via h2's end_stream()."""
        await self.connection.write_ready.wait()
        self._h2_connection.end_stream(self.id)
        self._transport.write(self._h2_connection.data_to_send())

    async def reset(self,
                    error_code: ErrorCodes = ErrorCodes.NO_ERROR) -> None:
        """Reset the stream with *error_code* once writing is allowed."""
        await self.connection.write_ready.wait()
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        self._transport.write(self._h2_connection.data_to_send())

    def reset_nowait(
        self,
        error_code: ErrorCodes = ErrorCodes.NO_ERROR,
    ) -> None:
        """Synchronous reset: always records the reset in h2, but writes to
        the transport only if writing is currently allowed.

        NOTE(review): when writing is paused the bytes presumably stay in h2's
        outgoing buffer until a later data_to_send() flush - confirm.
        """
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        if self.connection.write_ready.is_set():
            self._transport.write(self._h2_connection.data_to_send())

    def __ended__(self) -> None:
        """Signal end-of-data to the receive buffer."""
        self.buffer.eof()

    def __terminated__(self, reason: str) -> None:
        """Cancel the wrapper (if any) with StreamTerminatedError(reason)."""
        if self.wrapper is not None:
            self.wrapper.cancel(StreamTerminatedError(reason))

    @property
    def closable(self) -> bool:
        """True only while the connection is not CLOSED and h2 still tracks
        this stream as not closed."""
        if self._h2_connection.state_machine.state is ConnectionState.CLOSED:
            return False
        stream = self._h2_connection.streams.get(self.id)
        if stream is None:
            return False
        return not stream.closed
예제 #18
0
파일: protocol.py 프로젝트: travigd/grpclib
class Connection:
    """
    Per-connection state for one HTTP/2 connection.

    Owns the ``write_ready`` event (cleared by ``pause_writing``, set by
    ``resume_writing``) that streams wait on before writing, moves outgoing
    bytes from the ``H2Connection`` to the transport, and schedules keepalive
    pings when ``config._keepalive_time`` is not ``None``.
    """
    # stats
    streams_started = 0
    streams_succeeded = 0
    streams_failed = 0
    data_sent = 0
    data_received = 0
    messages_sent = 0
    messages_received = 0
    last_stream_created: Optional[float] = None
    last_data_sent: Optional[float] = None
    last_data_received: Optional[float] = None
    last_message_sent: Optional[float] = None
    last_message_received: Optional[float] = None
    last_ping_sent: Optional[float] = None
    ping_count_in_sequence: int = 0
    _ping_handle: Optional[TimerHandle] = None
    _close_by_ping_handler: Optional[TimerHandle] = None

    def __init__(
        self,
        connection: H2Connection,
        transport: Transport,
        *,
        config: Configuration,
    ) -> None:
        self._connection, self._transport = connection, transport
        self._config = config

        # Writing is permitted from the start; pause_writing() clears it.
        ready = Event()
        ready.set()
        self.write_ready = ready

        self.stream_close_waiter = Event()

    def feed(self, data: bytes) -> List[H2Event]:
        """Pass received bytes to h2 and return the events it produced."""
        events = self._connection.receive_data(data)
        return events  # type: ignore

    def ack(self, stream_id: int, size: int) -> None:
        """Acknowledge *size* consumed bytes on *stream_id*, then flush."""
        if not size:
            return
        self._connection.acknowledge_received_data(size, stream_id)
        self.flush()

    def pause_writing(self) -> None:
        """Block writers until resume_writing() is called."""
        self.write_ready.clear()

    def resume_writing(self) -> None:
        """Allow writers waiting on write_ready to proceed."""
        self.write_ready.set()

    def create_stream(
        self,
        *,
        stream_id: Optional[int] = None,
        wrapper: Optional[Wrapper] = None,
    ) -> 'Stream':
        """Create a Stream bound to this connection and its transport."""
        return Stream(
            self,
            self._connection,
            self._transport,
            stream_id=stream_id,
            wrapper=wrapper,
        )

    def flush(self) -> None:
        """Write whatever h2 has queued for sending, if anything."""
        pending = self._connection.data_to_send()
        if not pending:
            return
        self._transport.write(pending)

    def initialize(self) -> None:
        """Start the keepalive ping timer when a keepalive time is set."""
        keepalive = self._config._keepalive_time
        if keepalive is None:
            return
        self._ping_handle = asyncio.get_event_loop().call_later(
            keepalive, self._ping)

    def close(self) -> None:
        """Close the transport (once) and cancel any pending ping timers."""
        if hasattr(self, '_transport'):
            self._transport.close()
            # remove cyclic references to improve memory usage
            del self._transport
            if hasattr(self._connection, '_frame_dispatch_table'):
                del self._connection._frame_dispatch_table
        for timer in (self._ping_handle, self._close_by_ping_handler):
            if timer is not None:
                timer.cancel()

    def _is_need_send_ping(self) -> bool:
        """Decide whether a keepalive ping may be sent right now."""
        config = self._config
        assert config._keepalive_time is not None

        # Without permission to ping idle connections, require an open stream.
        if not config._keepalive_permit_without_calls and not any(
                s.open for s in self._connection.streams.values()):
            return False

        # Cap the number of consecutive pings without data/headers (0 = no cap).
        max_pings = config._http2_max_pings_without_data
        if max_pings != 0 and self.ping_count_in_sequence >= max_pings:
            return False

        # Enforce a minimum interval between pings.
        previous = self.last_ping_sent
        min_interval = config._http2_min_sent_ping_interval_without_data
        if previous is not None and time.monotonic() - previous < min_interval:
            return False

        return True

    def _ping(self) -> None:
        """Timer callback: maybe send a ping, then re-arm the ping timer."""
        assert self._config._keepalive_time is not None
        if self._is_need_send_ping():
            self._send_keepalive_ping()
        self._ping_handle = asyncio.get_event_loop().call_later(
            self._config._keepalive_time, self._ping)

    def _send_keepalive_ping(self) -> None:
        """Send one PING and arm the close-on-no-ack timer if not yet armed."""
        log.debug('send ping')
        # payload: current monotonic time in microseconds, 8-byte big-endian
        payload = struct.pack('!Q', int(time.monotonic() * 10**6))
        self._connection.ping(payload)
        self.flush()
        self.last_ping_sent = time.monotonic()
        self.ping_count_in_sequence += 1
        if self._close_by_ping_handler is not None:
            return
        # ping_ack_process() cancels this timer when the ack arrives
        self._close_by_ping_handler = asyncio.get_event_loop().call_later(
            self._config._keepalive_timeout, self.close)

    def headers_send_process(self) -> None:
        """Record that headers were sent (resets the ping sequence count)."""
        self.ping_count_in_sequence = 0

    def data_send_process(self) -> None:
        """Record that data was sent: reset ping count, stamp the time."""
        self.ping_count_in_sequence = 0
        self.last_data_sent = time.monotonic()

    def ping_ack_process(self) -> None:
        """A ping ack arrived: disarm the close-on-no-ack timer."""
        timer = self._close_by_ping_handler
        if timer is None:
            return
        timer.cancel()
        self._close_by_ping_handler = None
예제 #19
0
파일: protocol.py 프로젝트: travigd/grpclib
 def __init__(self, limit: Optional[int] = None) -> None:
     # Counter-style state: an optional cap, a running count that starts at
     # zero, and an event that starts out clear.
     # NOTE(review): limit=None presumably means "no cap" - confirm against the
     # acquire/release code, which is not visible here.
     self._limit, self._current = limit, 0
     self._release = Event()
예제 #20
0
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator which
    produces values mapped via calling the callback function.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.
    """
    def __init__(
        self,
        iterable: AsyncIterable,
        callback: Callable,
        reject_callback: Callable = None,
    ) -> None:
        """
        :param iterable: source; its ``__aiter__()`` is called once, here.
        :param callback: applied to every value; may return an awaitable.
        :param reject_callback: optional; when given, exceptions raised by the
            source (other than StopAsyncIteration / GeneratorExit) are passed
            to it and its result is produced instead of re-raising.
        """
        self.iterator = iterable.__aiter__()
        self.callback = callback
        self.reject_callback = reject_callback
        self._close_event = Event()

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_closed:
            # Once closed, only async generators are pulled any further;
            # anything else simply stops.
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            value = await self.iterator.__anext__()
            result = self.callback(value)

        else:
            # Race the next source item against a close request.
            aclose = ensure_future(self._close_event.wait())
            anext = ensure_future(self.iterator.__anext__())

            try:
                _, pending = await wait([aclose, anext],
                                        return_when=FIRST_COMPLETED)
            except BaseException:
                # Our own task was cancelled (or interrupted) while waiting:
                # don't leave the helper tasks running in the background.
                aclose.cancel()
                anext.cancel()
                raise

            for task in pending:
                task.cancel()

            if aclose.done():
                raise StopAsyncIteration

            error = anext.exception()
            if error:
                if not self.reject_callback or isinstance(
                        error, (StopAsyncIteration, GeneratorExit)):
                    raise error
                result = self.reject_callback(error)
            else:
                value = anext.result()
                result = self.callback(value)

        return await result if isawaitable(result) else result

    async def athrow(self, type_, value=None, traceback=None):
        """Throw an exception into the source (or raise it here and close,
        if the source has no ``athrow``)."""
        if not self.is_closed:
            athrow = getattr(self.iterator, "athrow", None)
            if athrow:
                await athrow(type_, value, traceback)
            else:
                self.is_closed = True
                if value is None:
                    if traceback is None:
                        raise type_
                    value = type_()
                if traceback is not None:
                    value = value.with_traceback(traceback)
                raise value

    async def aclose(self):
        """Close the source (if it supports it) and mark this iterator closed."""
        if not self.is_closed:
            aclose = getattr(self.iterator, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except RuntimeError:
                    pass
            self.is_closed = True

    @property
    def is_closed(self) -> bool:
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        # Setting the event also wakes any __anext__ racing against close.
        if value:
            self._close_event.set()
        else:
            self._close_event.clear()
예제 #21
0
 def __init__(self):
     # Fresh, unresolved state: a pending future, a cleared completion event
     # and no registered listeners yet.
     self._future, self._complete, self._listeners = Future(), Event(), []
예제 #22
0
 def __init__(self, maxsize=0, *, loop=None):
     # Let the base class initialise first; self._loop is read right after,
     # so it is presumably established by the base constructor.
     super().__init__(maxsize=maxsize, loop=loop)
     # A new queue is empty, so the "empty" event starts out set.
     event = Event(loop=self._loop)
     event.set()
     self.empty_event = event
예제 #23
0
class MQueue(Queue):
    """Queue of unique items; ``get()`` returns items in no particular order.

    Items are kept in a ``set``: putting an item that is already queued is a
    no-op. ``empty_event`` is set whenever the queue becomes empty and cleared
    when an item is added, so other tasks can wait for the queue to drain.
    """
    def __init__(self, maxsize=0, *, loop=None):
        if loop is None:
            # Python 3.10+ rejects any explicit ``loop`` argument, so only
            # forward it when the caller actually supplied one.
            super().__init__(maxsize=maxsize)
            self.empty_event = Event()
        else:
            super().__init__(maxsize=maxsize, loop=loop)
            self.empty_event = Event(loop=loop)
        self.empty_event.set()

    def _init(self, maxsize):
        # make item unique
        self._queue = set()

    def _get(self):
        # Warn: shouldn't call this method directly
        return self._queue.pop()

    def _put(self, item):
        # Warn: shouldn't call this method directly
        self._queue.add(item)

    def __contains__(self, item):
        return item in self._queue

    def get_nowait(self):
        item = super().get_nowait()
        if self.empty():
            self.empty_event.set()
        return item

    def _new_waiter(self):
        """Create a future on the queue's loop.

        Python 3.10+ binds the loop lazily via ``_get_loop()``; older
        versions store it in ``_loop``.
        """
        get_loop = getattr(self, "_get_loop", None)
        loop = get_loop() if get_loop is not None else self._loop
        return loop.create_future()

    async def _wait_not_empty(self):
        """Block until at least one item is queued (as asyncio.Queue.get)."""
        while self.empty():
            getter = self._new_waiter()
            self._getters.append(getter)
            try:
                await getter
            except BaseException:
                getter.cancel()  # Just in case getter is not done yet.
                try:
                    # Don't leave cancelled getters piling up in the deque.
                    self._getters.remove(getter)
                except ValueError:
                    # put_nowait() may already have removed it.
                    pass
                if not self.empty() and not getter.cancelled():
                    # We were woken up by put_nowait(), but can't take
                    # the call.  Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise

    async def get(self, maxcount=1):
        """Pop at most *maxcount* items from the queue (as a list),
        blocking while the queue is empty.
        """
        await self._wait_not_empty()
        items = []
        while not self.empty() and maxcount > 0:
            items.append(self._get())
            # A slot was freed: let a put() blocked on a full queue proceed,
            # exactly as asyncio.Queue.get_nowait() does.
            self._wakeup_next(self._putters)
            maxcount -= 1
        if self.empty():
            # notify others queue is empty
            self.empty_event.set()
        return items

    async def get_all(self):
        """Return all queued items as a list without removing them,
        blocking while the queue is empty.
        """
        await self._wait_not_empty()
        return list(self._queue)

    def put_nowait(self, item):
        # Duplicates are silently ignored.
        if item in self._queue:
            return
        super().put_nowait(item)
        self.empty_event.clear()
예제 #24
0
파일: compound.py 프로젝트: iakinsey/illume
 def __init__(self, queues, loop):
     # Two events bound to *loop* (both start clear), plus the queues and
     # the loop itself kept for later use.
     self.ready, self.stop_event = Event(loop=loop), Event(loop=loop)
     self.queues, self.loop = queues, loop
예제 #25
0
파일: pool.py 프로젝트: tzoiker/aiomisc
        async def handler(start_event: asyncio.Event) -> None:
            """Authenticate one worker connection, then feed it tasks.

            Performs the salt/digest cookie handshake, runs the optional
            initializer (setting *start_event* either way), and then loops
            pulling ``(func, args, kwargs, result_future, process_future)``
            tuples from ``self.tasks`` until the worker goes away.
            """
            # local import keeps this fix self-contained
            import hmac

            log.debug("Starting to handle client")

            packet_type, salt = await receive()
            assert packet_type == PacketTypes.AUTH_SALT

            packet_type, digest = await receive()
            assert packet_type == PacketTypes.AUTH_DIGEST

            hasher = HASHER()
            hasher.update(salt)
            hasher.update(self.__cookie)

            # Constant-time comparison: a plain ``!=`` on the digest leaks
            # timing information to the (untrusted) peer. A non-bytes digest
            # can never match, and compare_digest would reject it with a
            # TypeError, so treat it as a mismatch up front.
            if not isinstance(digest, (bytes, bytearray)) or \
                    not hmac.compare_digest(digest, hasher.digest()):
                exc = AuthenticationError("Invalid cookie")
                await send(PacketTypes.EXCEPTION, exc)
                raise exc

            await send(PacketTypes.AUTH_OK, True)

            log.debug("Client authorized")

            packet_type, identity = await receive()
            assert packet_type == PacketTypes.IDENTITY
            process = self.__spawning.pop(identity)
            starting: asyncio.Future = self.__starting.pop(identity)

            if self.initializer is not None:
                initializer_done = self.__create_future()

                await step(self.initializer, self.initializer_args,
                           dict(self.initializer_kwargs), initializer_done)

                try:
                    await initializer_done
                except Exception as e:
                    starting.set_exception(e)
                    raise
                else:
                    starting.set_result(None)
                finally:
                    start_event.set()
            else:
                starting.set_result(None)
                start_event.set()

            while True:
                func: Callable
                args: Tuple[Any, ...]
                kwargs: Dict[str, Any]
                result_future: asyncio.Future
                process_future: asyncio.Future

                (
                    func,
                    args,
                    kwargs,
                    result_future,
                    process_future,
                ) = await self.tasks.get()

                try:
                    # skip tasks whose caller already gave up
                    if process_future.done():
                        continue

                    process_future.set_result(process)

                    if result_future.done():
                        continue

                    await step(func, args, kwargs, result_future)
                except asyncio.IncompleteReadError:
                    # the worker went away mid-task
                    await self.__wait_process(process)
                    self.__on_exit(process)

                    # The caller may already have cancelled/resolved the
                    # future; setting it again would raise InvalidStateError.
                    if not result_future.done():
                        result_future.set_exception(
                            ProcessError(
                                "Process {!r} exited with code {!r}".format(
                                    process,
                                    process.returncode,
                                ), ), )
                    break
                except Exception as e:
                    if not result_future.done():
                        self.loop.call_soon(result_future.set_exception, e)

                    if not writer.is_closing():
                        self.loop.call_soon(writer.close)

                    await self.__wait_process(process)
                    self.__on_exit(process)

                    raise
예제 #26
0
파일: stores.py 프로젝트: def-nn/tornadose
class RedisStore(BaseStore):
    """Publish data via a Redis backend.

    This data store works in a similar manner as
    :class:`DataStore`. The primary advantage is that external
    programs can be used to publish data to be consumed by clients.

    The ``channel`` keyword argument specifies which Redis channel to
    publish to and defaults to ``tornadose``.

    All remaining keyword arguments are passed directly to the
    ``redis.StrictRedis`` constructor. See `redis-py`__'s
    documentation for details.

    New messages are read in a background thread via a
    :class:`concurrent.futures.ThreadPoolExecutor`.

    __ https://redis-py.readthedocs.org/en/latest/

    :raises ConnectionError: when the Redis host is not pingable

    """
    def initialize(self, channel='tornadose', **kwargs):
        if redis is None:
            raise RuntimeError("The redis module is required to use RedisStore")

        self.executor = ThreadPoolExecutor(max_workers=1)
        self.channel = channel
        self.messages = Queue()
        self._done = Event()

        self._redis = redis.StrictRedis(**kwargs)
        self._redis.ping()
        self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
        self._pubsub.subscribe(self.channel)

        # ``publish`` is a native coroutine: calling it directly only creates
        # a coroutine object that never runs. Schedule it on the current
        # IOLoop so the background read loop actually starts.
        IOLoop.current().spawn_callback(self.publish)

    def submit(self, message, debug=False):
        """Publish *message* on the channel.

        With *debug*, also log it and store it under the channel key with a
        5-second expiry (redis ``SETEX``).
        """
        self._redis.publish(self.channel, message)
        if debug:
            logger.debug(message)
            self._redis.setex(self.channel, 5, message)

    def shutdown(self):
        """Stop the publishing loop."""
        self._done.set()
        self.executor.shutdown(wait=False)

    def _get_message(self):
        """Poll pubsub (up to 1 s); return the message payload or ``None``."""
        data = self._pubsub.get_message(timeout=1)
        if data is not None:
            data = data['data']
        return data

    async def publish(self):
        """Forward messages read from Redis to all subscribers until
        :meth:`shutdown` is called."""
        loop = IOLoop.current()

        while not self._done.is_set():
            # the blocking pubsub poll runs on the executor thread
            data = await loop.run_in_executor(self.executor,
                                              self._get_message)
            if data is None:
                continue
            for subscriber in self.subscribers:
                subscriber.submit(data)
예제 #27
0
파일: queue.py 프로젝트: nbashev/noc
 def _notify_waiter(self, waiter: asyncio.Event) -> None:
     """Set *waiter*, via ``self.loop.call_soon_threadsafe`` when a loop is
     configured (truthy), otherwise directly in the calling thread."""
     loop = self.loop
     if not loop:
         waiter.set()
         return
     loop.call_soon_threadsafe(waiter.set)
예제 #28
0
파일: slurm.py 프로젝트: FragMAX/FragMAXapp
 def __init__(self):
     # Start with a zero timeout and a cleared restart event.
     self._next_timeout, self.restart_event = 0, Event()
예제 #29
0
    async def fetch_event(self,
                          first_fetch_time: datetime,
                          initial_offset: int = 0,
                          event_type: str = '') -> AsyncGenerator[Dict, None]:
        """Retrieves events from a CrowdStrike Falcon stream starting from given offset.

        Each outer iteration starts a discovery/refresh task, waits until it
        sets the ``Event`` it was given, then streams newline-delimited JSON
        from ``self.data_feed_url``. Events created before
        ``first_fetch_time`` and undecodable lines are skipped. Any other
        exception is logged, followed by a 10 second sleep and a retry.

        Args:
            first_fetch_time (datetime): The start time to fetch from retroactively for the first fetch.
            initial_offset (int): Stream offset to start the fetch from.
            event_type (str): Stream event type to fetch.

        Yields:
            AsyncGenerator[Dict, None]: Event fetched from the stream.
        """
        # asyncio keeps only weak references to tasks; hold strong ones here so
        # a discovery/refresh task cannot be garbage-collected mid-execution.
        background_tasks = set()
        while True:
            demisto.debug('Fetching event')
            event = Event()
            refresh_task = create_task(self._discover_refresh_stream(event))
            background_tasks.add(refresh_task)
            refresh_task.add_done_callback(background_tasks.discard)
            demisto.debug('Waiting for stream discovery or refresh')
            await event.wait()
            demisto.debug('Done waiting for stream discovery or refresh')
            events_fetched = 0
            new_lines_fetched = 0
            last_fetch_stats_print = datetime.utcnow()
            async with ClientSession(
                    connector=TCPConnector(ssl=self.verify_ssl),
                    headers={
                        'Authorization': f'Token {self.session_token}',
                        'Connection': 'keep-alive'
                    },
                    trust_env=self.proxy,
                    timeout=None) as session:
                try:
                    integration_context = get_integration_context()
                    offset = integration_context.get('offset',
                                                     0) or initial_offset
                    demisto.debug(
                        f'Starting to fetch from offset {offset} events of type {event_type} '
                        f'from time {first_fetch_time}')
                    async with session.get(self.data_feed_url,
                                           params={
                                               'offset': offset,
                                               'eventType': event_type
                                           },
                                           timeout=None) as res:
                        demisto.debug(f'Fetched event: {res.content}')
                        async for line in res.content:
                            stripped_line = line.strip()
                            if stripped_line:
                                events_fetched += 1
                                try:
                                    streaming_event = json.loads(stripped_line)
                                    event_metadata = streaming_event.get(
                                        'metadata', {})
                                    event_creation_time = event_metadata.get(
                                        'eventCreationTime', 0)
                                    if not event_creation_time:
                                        demisto.debug(
                                            'Could not extract "eventCreationTime" field, using 0 instead. '
                                            f'{streaming_event}')
                                    else:
                                        # the division suggests the field is in milliseconds
                                        event_creation_time /= 1000
                                    event_creation_time_dt = datetime.fromtimestamp(
                                        event_creation_time)
                                    if event_creation_time_dt < first_fetch_time:
                                        demisto.debug(
                                            f'Event with offset {event_metadata.get("offset")} '
                                            f'and creation time {event_creation_time} was skipped.'
                                        )
                                        continue
                                    yield streaming_event
                                except json.decoder.JSONDecodeError:
                                    demisto.debug(
                                        f'Failed decoding event (skipping it) - {str(stripped_line)}'
                                    )
                            else:
                                # blank lines: counted separately in the stats
                                new_lines_fetched += 1
                            if last_fetch_stats_print + timedelta(
                                    minutes=1) <= datetime.utcnow():
                                demisto.info(
                                    f'Fetched {events_fetched} events and'
                                    f' {new_lines_fetched} new lines'
                                    f' from the stream in the last minute.')
                                events_fetched = 0
                                new_lines_fetched = 0
                                last_fetch_stats_print = datetime.utcnow()
                except Exception as e:
                    demisto.debug(
                        f'Failed to fetch event: {e} - Going to sleep for 10 seconds and then retry -'
                        f' {traceback.format_exc()}')
                    await sleep(10)
예제 #30
0
 def __init__(self, limit=None, *, loop):
     # Optional cap plus a running count that starts at zero; the loop is
     # kept and the (initially clear) release event is created on it.
     self._limit, self._current = limit, 0
     self._loop = loop
     self._release = Event(loop=loop)
예제 #31
0
async def cancel_wrapper(stream: Stream[_TSend, _TRecv],
                         stop: asyncio.Event) -> AsyncIterator[_TRecv]:
    """Re-yield what ``stop_wrapper(stream, stop)`` produces; if iteration
    ended because *stop* is set, cancel *stream* afterwards."""
    relayed = stop_wrapper(stream, stop)
    async for item in relayed:
        yield item
    if not stop.is_set():
        return
    await stream.cancel()
예제 #32
0
class Stream:
    """
    API for working with streams, used by clients and request handlers

    A stream is bound to one HTTP/2 stream id of ``h2_connection``.
    Outgoing frames are produced by the h2 state machine and flushed to
    ``transport`` via ``data_to_send()`` after every send operation.
    """
    # HTTP/2 stream id; None until assigned in __init__ (stream_id given)
    # or allocated by send_request().
    id = None
    # Receive Buffer for this stream; created at the same time as the id.
    __buffer__ = None
    # Optional wrapper cancelled with StreamTerminatedError in
    # __terminated__().
    __wrapper__ = None

    def __init__(self,
                 connection: Connection,
                 h2_connection: H2Connection,
                 transport: Transport,
                 *,
                 loop: AbstractEventLoop,
                 stream_id: Optional[int] = None,
                 wrapper: Optional[Wrapper] = None) -> None:
        """Bind the stream to its connection, h2 state machine and transport.

        :param stream_id: when given, the id and the receive Buffer are set
            up immediately; otherwise ``send_request()`` allocates them.
        :param wrapper: optional object whose ``cancel()`` is called when
            the stream is terminated.
        """
        self._connection = connection
        self._h2_connection = h2_connection
        self._transport = transport
        self._loop = loop
        self.__wrapper__ = wrapper

        if stream_id is not None:
            self.id = stream_id
            self.__buffer__ = Buffer(self.id,
                                     self._connection,
                                     self._h2_connection,
                                     loop=self._loop)

        # Queue of received header lists, consumed by recv_headers*().
        # NOTE(review): asyncio.Queue/Event dropped the loop= argument in
        # Python 3.10; this only works on older interpreters if these are
        # the asyncio primitives — confirm.
        self.__headers__ = Queue(loop=loop) \
            # type: Queue[List[Tuple[str, str]]]

        # Cleared and awaited by send_data() when the flow-control window
        # is exhausted; presumably set elsewhere on WINDOW_UPDATE.
        self.__window_updated__ = Event(loop=loop)

    async def recv_headers(self):
        """Wait for and return the next item from the headers queue."""
        return await self.__headers__.get()

    def recv_headers_nowait(self):
        """Return the next queued headers, or None if the queue is empty."""
        try:
            return self.__headers__.get_nowait()
        except QueueEmpty:
            return None

    async def recv_data(self, size):
        """Delegate to ``Buffer.read(size)`` of this stream's buffer.

        The exact semantics of *size* depend on Buffer (not shown here).
        """
        return await self.__buffer__.read(size)

    async def send_request(self, headers, end_stream=False, *, _processor):
        """Open a new HTTP/2 stream by sending *headers* on it.

        Retries when h2 refuses a new stream (TooManyStreamsError) by
        waiting until another stream closes.

        :returns: whatever ``_processor.register(self)`` returns (used by
            callers to release the stream).
        """
        assert self.id is None, self.id
        while True:
            # this is the first thing we should check before even trying to
            # create new stream, because this wait() can be cancelled by timeout
            # and we wouldn't need to create new stream at all
            if not self._connection.write_ready.is_set():
                await self._connection.write_ready.wait()

            # `get_next_available_stream_id()` should be as close to
            # `connection.send_headers()` as possible, without any async
            # interruptions in between, see the docs on the
            # `get_next_available_stream_id()` method
            stream_id = self._h2_connection.get_next_available_stream_id()
            try:
                self._h2_connection.send_headers(stream_id,
                                                 headers,
                                                 end_stream=end_stream)
            except TooManyStreamsError:
                # we're going to wait until any of currently opened streams will
                # be closed, and we will be able to open a new one
                # TODO: maybe implement FIFO for waiters, but this limit
                #       shouldn't be reached in a normal case, so why bother
                # TODO: maybe we should raise an exception here instead of
                #       waiting, if timeout wasn't set for the current request
                self._connection.stream_close_waiter.clear()
                await self._connection.stream_close_waiter.wait()
                # while we were trying to create a new stream, the write buffer
                # can become full, so we need to repeat the checks, starting
                # with whether we can write() data
                continue
            else:
                # Stream accepted by h2: bind the id, create the receive
                # buffer, register with the processor, then flush.
                self.id = stream_id
                self.__buffer__ = Buffer(self.id,
                                         self._connection,
                                         self._h2_connection,
                                         loop=self._loop)
                release_stream = _processor.register(self)
                self._transport.write(self._h2_connection.data_to_send())
                return release_stream

    async def send_headers(self, headers, end_stream=False):
        """Send *headers* on this already-open stream.

        :raises StreamClosedError: if h2 no longer tracks this stream id.
        """
        assert self.id is not None
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()

        # Workaround for the H2Connection.send_headers method, which will try
        # to create a new stream if it was removed earlier from the
        # H2Connection.streams, and therefore will raise StreamIDTooLowError
        if self.id not in self._h2_connection.streams:
            raise StreamClosedError(self.id)

        self._h2_connection.send_headers(self.id,
                                         headers,
                                         end_stream=end_stream)
        self._transport.write(self._h2_connection.data_to_send())

    async def send_data(self, data, end_stream=False):
        """Send *data* as DATA frames, respecting flow control.

        Each chunk is at most min(stream window, max outbound frame size,
        remaining bytes). ``end_stream`` is applied only to the final chunk.
        """
        f = BytesIO(data)
        f_pos, f_last = 0, len(data)

        while True:
            if not self._connection.write_ready.is_set():
                await self._connection.write_ready.wait()

            window = self._h2_connection.local_flow_control_window(self.id)
            if not window:
                # Window exhausted: wait for __window_updated__ to be set,
                # then re-read the window.
                self.__window_updated__.clear()
                await self.__window_updated__.wait()
                window = self._h2_connection.local_flow_control_window(self.id)

            max_frame_size = self._h2_connection.max_outbound_frame_size
            f_chunk = f.read(min(window, max_frame_size, f_last - f_pos))
            f_pos = f.tell()

            if f_pos == f_last:
                # Last chunk (or empty data): carries end_stream.
                self._h2_connection.send_data(self.id,
                                              f_chunk,
                                              end_stream=end_stream)
                self._transport.write(self._h2_connection.data_to_send())
                break
            else:
                self._h2_connection.send_data(self.id, f_chunk)
                self._transport.write(self._h2_connection.data_to_send())

    async def end(self):
        """Half-close the stream from our side once writing is allowed."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.end_stream(self.id)
        self._transport.write(self._h2_connection.data_to_send())

    async def reset(self, error_code=ErrorCodes.NO_ERROR):
        """Reset the stream with *error_code* once writing is allowed."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        self._transport.write(self._h2_connection.data_to_send())

    def reset_nowait(self, error_code=ErrorCodes.NO_ERROR):
        """Reset the stream without waiting.

        The h2 state is updated immediately. Bytes are written only if the
        connection is currently write-ready; otherwise they stay in h2's
        outbound buffer (presumably flushed by a later data_to_send()).
        """
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        if self._connection.write_ready.is_set():
            self._transport.write(self._h2_connection.data_to_send())

    def __ended__(self):
        """Mark end-of-data on the receive buffer."""
        self.__buffer__.eof()

    def __terminated__(self, reason):
        """Cancel the wrapper (if any) with StreamTerminatedError(reason)."""
        if self.__wrapper__ is not None:
            self.__wrapper__.cancel(StreamTerminatedError(reason))

    @property
    def closable(self):
        """True while the connection is open and h2 tracks this stream as
        not yet closed."""
        if self._h2_connection.state_machine.state is ConnectionState.CLOSED:
            return False
        stream = self._h2_connection.streams.get(self.id)
        if stream is None:
            return False
        return not stream.closed
예제 #33
0
class AsyncConnectionPool:
    """Object manages asynchronous connections.

    :param int size: size (number of connection) of the pool.
    :param float queue_timeout: time out when client is waiting connection
        from pool
    :param loop: event loop, if not passed then default will be used
    :param config: MySql connection config see
        `doc. <http://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html>`_
    :raise ValueError: if the `size` is inappropriate
    """
    def __init__(self, size=1, queue_timeout=15.0, *, loop=None, **config):
        # Explicit raise (not ``assert``) so callers always get the
        # documented ValueError, including under ``python -O``.
        if size < 1:
            raise ValueError('DBPool.size is less than 1, '
                             'connections won"t be established')
        self._pool = set()          # every connection created by the pool
        self._busy_items = set()    # connections currently handed out
        self._size = size
        # Futures of callers waiting for a connection, oldest first.
        self._pending_futures = deque()
        self._queue_timeout = queue_timeout
        self._loop = loop or asyncio.get_event_loop()
        self.config = config

        # Set == pool usable. shutdown() clears it while it runs so that
        # get() blocks until the shutdown has finished.
        self._shutdown_event = Event(loop=self._loop)
        self._shutdown_event.set()

    @property
    def queue_timeout(self):
        """Number of seconds to wait a connection from the pool,
        before TimeoutError occurred

        :rtype: float
        """
        return self._queue_timeout

    @queue_timeout.setter
    def queue_timeout(self, value):
        """Sets a timeout for :attr:`queue_timeout`

        :param float value: number of seconds
        """
        if not isinstance(value, (float, int)):
            raise ValueError('Float or integer type expected')
        self._queue_timeout = value

    @property
    def size(self):
        """Size of pool

        :rtype: int
        """
        return self._size

    def __len__(self):
        """Number of allocated pool's slots

        :rtype: int
        """
        return len(self._pool)

    @property
    def free_count(self):
        """Number of free pool's slots

        :rtype: int
        """
        return self.size - len(self._busy_items)

    @asyncio.coroutine
    def get(self):
        """Coroutine. Returns an opened connection from pool.
        If coroutine invoked when all connections have been issued, then
        caller will blocked until some connection will be released.

        Also, the class provides context manager for getting connection
        and automatically freeing it. Example:
        >>> with (yield from pool) as cnx:
        >>>     ...

        :rtype: AsyncMySQLConnection
        :raise: concurrent.futures.TimeoutError()
        """
        cnx = None

        # Do not hand out connections while shutdown() is in progress.
        yield from self._shutdown_event.wait()

        for free_client in self._pool - self._busy_items:
            cnx = free_client
            self._busy_items.add(cnx)
            break
        else:
            if len(self) < self.size:
                # Reserve the slot before the first await so concurrent
                # get() calls cannot exceed the pool size.
                cnx = AsyncMySQLConnection(loop=self._loop)
                self._pool.add(cnx)
                self._busy_items.add(cnx)

                try:
                    yield from cnx.connect(**self.config)
                except BaseException:
                    # Give the slot back on any failure (incl. cancellation).
                    self._pool.remove(cnx)
                    self._busy_items.remove(cnx)
                    raise

        if cnx is None:
            # Pool exhausted: queue up and wait for release() to hand over
            # a connection.
            queue_future = Future(loop=self._loop)
            self._pending_futures.append(queue_future)
            try:
                cnx = yield from asyncio.wait_for(queue_future,
                                                  self.queue_timeout,
                                                  loop=self._loop)
                self._busy_items.add(cnx)
            except TimeoutError:
                raise TimeoutError('Database pool is busy')
            finally:
                try:
                    self._pending_futures.remove(queue_future)
                except ValueError:
                    # Already popped by release().
                    pass

        return cnx

    def release(self, connection):
        """Frees connection. After that the connection can be issued
        by :func:`get`.

        :param AsyncMySQLConnection connection: a connection received
            from :func:`get`
        """
        # Hand the connection directly to the oldest waiter that is still
        # waiting. A waiter that was cancelled or timed out can still be in
        # the deque until its get() runs its ``finally``; calling
        # set_result() on it would raise InvalidStateError and leak the
        # connection, so such waiters are skipped.
        while self._pending_futures:
            waiter = self._pending_futures.popleft()
            if not waiter.done():
                waiter.set_result(connection)
                return
        self._busy_items.remove(connection)

    @asyncio.coroutine
    def shutdown(self):
        """Coroutine. Closes all connections and purge queue of a waiting
        for connection.
        """
        self._shutdown_event.clear()
        try:
            # Iterate over a snapshot: disconnect() yields to the loop, and a
            # get() that already passed the shutdown check may still add to
            # self._pool meanwhile.
            for cnx in list(self._pool):
                yield from cnx.disconnect()

            for f in self._pending_futures:
                f.cancel()

            self._pending_futures.clear()
            self._pool = set()
            self._busy_items = set()
        finally:
            self._shutdown_event.set()

    def __enter__(self):
        raise RuntimeError(
            '"yield from" should be used as context manager expression')

    def __exit__(self, *args):
        # This must exist because __enter__ exists, even though that
        # always raises; that's how the with-statement works.
        pass

    @asyncio.coroutine
    def __iter__(self):
        """Support ``with (yield from pool) as cnx:``."""
        cnx = yield from self.get()
        return ContextManager(self, cnx)
예제 #34
0
async def update_lifetime(ev: asyncio.Event, lifetime: int):
    """Log a lifetime heartbeat about once per second until *ev* is set.

    *ev* is checked before every iteration, so an event that is already
    set makes the coroutine return immediately without logging.
    """
    tick = 0
    while True:
        if ev.is_set():
            return
        logger.info("update lifetime:%d %d", tick, lifetime + 20)  # xxx: +20?
        await asyncio.sleep(1)
        tick += 1
예제 #35
0
 async def wait_for_ready_complete(self, guild: Guild):
     """Block until the ready event registered for *guild* is set.

     If no event exists yet for ``guild.id``, one is created and stored in
     ``self.ready_locks`` first.
     """
     ready_locks = self.ready_locks
     if guild.id in ready_locks:
         ready = ready_locks[guild.id]
     else:
         ready = Event()
         ready_locks[guild.id] = ready
     await ready.wait()
예제 #36
0
 def __init__(self, totals: Dict[str, int]) -> None:
     """Create a zero-progress ``ProgressBarState`` for each key of *totals*.

     :param totals: maps each key to the total passed as the second
         ``ProgressBarState`` argument.
     """
     states = {}
     for name in totals.keys():
         states[name] = ProgressBarState(0, totals[name])
     self.state = states
     self.event = Event()
예제 #37
0
class NabIO(object, metaclass=abc.ABCMeta):
    """ Interface for I/O interactions with a nabaztag """

    # https://github.com/nabaztag2018/hardware/blob/master/RPI_Nabaztag.PDF
    MODEL_2018 = 1
    # https://github.com/nabaztag2018/hardware/blob/master/
    # pyNab_V4.1_voice_reco.PDF
    MODEL_2019_TAG = 2
    # with RFID
    MODEL_2019_TAGTAG = 3

    # Each info loop lasts 15 seconds
    INFO_LOOP_LENGTH = 15.0

    def __init__(self):
        super().__init__()
        # Set by cancel(); cleared at the start of play_message() and
        # play_sequence(), and polled while preloading/playing.
        self.cancel_event = Event()

    @abc.abstractmethod
    async def setup_ears(self, left_ear, right_ear):
        """
        Init ears and move them to the initial position.
        """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    async def move_ears(self, left_ear, right_ear):
        """
        Move ears to a given position and return only when they reached this
        position.
        """
        raise NotImplementedError("Should have implemented")

    async def move_ears_with_leds(self, color, new_left, new_right):
        """
        If ears are not in given position, set LEDs to given color, move ears,
        turn LEDs off and return.

        An ear reported as broken by ``self.ears.is_broken`` does not trigger
        a move on its own. LEDs are turned off in every case.
        """
        do_move = False
        current_left, current_right = await self.ears.get_positions()
        if current_left != new_left:
            if not self.ears.is_broken(Ears.LEFT_EAR):
                do_move = True
        if current_right != new_right:
            if not self.ears.is_broken(Ears.RIGHT_EAR):
                do_move = True
        if do_move:
            self.set_leds(color, color, color, color, color)
            await self.move_ears(new_left, new_right)
        self.set_leds(None, None, None, None, None)

    @abc.abstractmethod
    async def detect_ears_positions(self):
        """
        Detect ears positions and return the position before the detection.
        A second call will return the current position.
        """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def set_leds(self, nose, left, center, right, bottom):
        """ Set the leds. None means to turn them off. """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def pulse(self, led, color):
        """ Set a led to pulse. """
        raise NotImplementedError("Should have implemented")

    async def rfid_detected_feedback(self):
        """Play the RFID choreography and sound, then turn LEDs off."""
        ci = ChoreographyInterpreter(self.leds, self.ears, self.sound)
        await ci.start("nabd/rfid.chor")
        await self.sound.play_list(["rfid/rfid.wav"], False)
        await ci.stop()
        self.set_leds(None, None, None, None, None)

    def rfid_awaiting_feedback(self):
        """
        Turn the nose magenta (RGB 255, 0, 255) and set the other LEDs to
        (0, 0, 0).
        """
        self.set_leds((255, 0, 255), (0, 0, 0), (0, 0, 0), (0, 0, 0),
                      (0, 0, 0))

    def rfid_done_feedback(self):
        """
        Turn everything off.
        """
        self.set_leds((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))

    @abc.abstractmethod
    def bind_button_event(self, loop, callback):
        """
        Define the callback for button events.
        callback is cb(event_type, time) with event_type being:
        - 'down'
        - 'up'
        - 'long_down'
        - 'double_click'
        - 'click_and_hold'

        Make sure the callback is called on the provided event loop, with
        loop.call_soon_threadsafe
        """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def bind_ears_event(self, loop, callback):
        """
        Define the callback for ears events.
        callback is cb(ear) ear being the ear moved.

        Make sure the callback is called on the provided event loop, with
        loop.call_soon_threadsafe
        """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def bind_rfid_event(self, loop, callback):
        """
        Define the callback for rfid events.
        callback is cb(uid, picture, app, data, flags)

        Make sure the callback is called on the provided event loop, with
        loop.call_soon_threadsafe
        """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    async def play_info(self, condvar, tempo, colors):
        """
        Play an info animation.
        tempo & colors are as described in the nabd protocol.
        Run the animation in loop for the complete info duration (15 seconds)
        or until condvar is notified

        If 'left'/'center'/'right' slots are absent, the light is off.
        Return true if condvar was notified
        """
        raise NotImplementedError("Should have implemented")

    async def start_acquisition(self, acquisition_cb):
        """
        Set the nose magenta, play the listen sound, then start recording,
        calling acquisition_cb with sound samples.
        """
        self.set_leds((255, 0, 255), (0, 0, 0), (0, 0, 0), (0, 0, 0),
                      (0, 0, 0))
        await self.sound.play_list(["asr/listen.mp3"], False)
        await self.sound.start_recording(acquisition_cb)

    async def end_acquisition(self):
        """
        Stop recording, then play the "acquired" sound.
        """
        await self.sound.stop_recording()
        await self.sound.play_list(["asr/acquired.mp3"], False)

    async def asr_failed(self):
        """
        Feedback when ASR or NLU failed.
        """
        await self.sound.play_list(["asr/failed/*.mp3"], False)

    async def play_message(self, signature, body):
        """
        Play a message, i.e. a signature, a body and a signature.
        """
        self.cancel_event.clear()
        # Turn leds red while ears go to 0, 0
        await self.move_ears_with_leds((255, 0, 0), 0, 0)
        preloaded_sig = await self._preload([signature])
        preloaded_body = await self._preload(body)
        ci = ChoreographyInterpreter(self.leds, self.ears, self.sound)
        await self._play_preloaded(ci, preloaded_sig,
                                   ChoreographyInterpreter.STREAMING_URN)
        await self._play_preloaded(ci, preloaded_body,
                                   ChoreographyInterpreter.STREAMING_URN)
        await self._play_preloaded(ci, preloaded_sig,
                                   ChoreographyInterpreter.STREAMING_URN)
        await ci.stop()
        self.set_leds(None, None, None, None, None)

    async def play_sequence(self, sequence):
        """
        Play a simple sequence
        """
        self.cancel_event.clear()
        preloaded = await self._preload(sequence)
        ci = ChoreographyInterpreter(self.leds, self.ears, self.sound)
        await self._play_preloaded(ci, preloaded, None)

    async def _play_preloaded(self, ci, preloaded, default_chor):
        """Play preloaded items in order, stopping early if cancelled.

        An item's own "choreography" overrides *default_chor*. Items with
        audio play their audio (interruptible via cancel_event); items with
        only a choreography wait for it to complete.
        """
        for seq_item in preloaded:
            if self.cancel_event.is_set():
                break
            if "choreography" in seq_item:
                chor = seq_item["choreography"]
            else:
                chor = default_chor
            if chor is not None:
                await ci.start(chor)
            else:
                await ci.stop()
            if "audio" in seq_item:
                await self.sound.play_list(seq_item["audio"], True,
                                           self.cancel_event)
                if chor is not None:
                    await ci.stop()
            elif "choreography" in seq_item:
                await ci.wait_until_complete(self.cancel_event)

    async def _preload(self, sequence):
        """Return a copy of *sequence* with every "audio" list preloaded.

        Items carrying "audio" are shallow-copied before their "audio" entry
        is replaced, so the caller's *sequence* is left untouched and can be
        replayed. A bare string "audio" is accepted (with a warning) as a
        one-element list. Resources for which ``self.sound.preload`` returns
        None are dropped. Stops early if cancel_event is set.
        """
        preloaded_sequence = []
        for seq_item in sequence:
            if self.cancel_event.is_set():
                break
            if "audio" in seq_item:
                preloaded_audio_list = []
                if isinstance(seq_item["audio"], str):
                    print(f"Warning: audio should be a list of resources "
                          f"(sequence item: {seq_item})")
                    audio_list = [seq_item["audio"]]
                else:
                    audio_list = seq_item["audio"]
                for res in audio_list:
                    f = await self.sound.preload(res)
                    if f is not None:
                        preloaded_audio_list.append(f)
                # Copy instead of mutating the caller's item in place.
                seq_item = dict(seq_item)
                seq_item["audio"] = preloaded_audio_list
            preloaded_sequence.append(seq_item)
        return preloaded_sequence

    async def cancel(self, feedback=False):
        """
        Cancel currently running sequence or info animation.
        """
        self.cancel_event.set()
        if feedback:
            await self.sound.play_list(["nabd/abort.wav"], False)

    @abc.abstractmethod
    async def gestalt(self):
        """ Return a structure representing hardware info. """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def has_sound_input(self):
        """ Determine if we have sound input """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    def has_rfid(self):
        """ Determine if we have an rfid reader """
        raise NotImplementedError("Should have implemented")

    @abc.abstractmethod
    async def test(self, test):
        """ Run a given hardware test, returning True if everything is ok """
        raise NotImplementedError("Should have implemented")
예제 #38
0

if __name__ == '__main__':
    # Log to ./logs next to this file (rotating, 5 x 100 MiB) and to stderr.
    log_dir = os.path.join(os.path.dirname(__file__), './logs')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger("loader")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        '%(asctime)s  %(levelname)s  %(filename)s  %(lineno)s  %(message)s')
    log_file = os.path.join(log_dir, 'updater.log')
    handler = RotatingFileHandler(
        log_file, maxBytes=100 * 1024 * 1024, backupCount=5)
    stream = logging.StreamHandler()
    stream.setFormatter(formatter)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(stream)

    ex_ev = Event()
    app = Loader(ex_ev, logger)


    def signal_handler(sig, frame):
        # Ask the loader to exit on SIGTERM / SIGINT.
        app.exit_event.set()


    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    app.run()
예제 #39
0
파일: recovery.py 프로젝트: smaxtec/faust
class Recovery(Service):
    """Service responsible for recovering tables from changelog topics."""

    app: AppT

    tables: _TableManager

    stats_interval: float = 5.0

    #: Set of standby topic partitions.
    standby_tps: Set[TP]

    #: Set of active topic partitions.
    active_tps: Set[TP]

    actives_for_table: MutableMapping[CollectionT, Set[TP]]
    standbys_for_table: MutableMapping[CollectionT, Set[TP]]

    #: Mapping from topic partition to table
    tp_to_table: MutableMapping[TP, CollectionT]

    #: Active offset by topic partition.
    active_offsets: Counter[TP]

    #: Standby offset by topic partition.
    standby_offsets: Counter[TP]

    #: Mapping of highwaters by topic partition.
    highwaters: Counter[TP]

    #: Active highwaters by topic partition.
    active_highwaters: Counter[TP]

    #: Standby highwaters by topic partition.
    standby_highwaters: Counter[TP]

    _signal_recovery_start: Optional[Event] = None
    _signal_recovery_end: Optional[Event] = None

    completed: Event
    in_recovery: bool = False
    standbys_pending: bool = False
    recovery_delay: float

    #: Changelog event buffers by table.
    #: These are filled by background task `_slurp_changelog`,
    #: and need to be flushed before starting new recovery/stopping.
    buffers: MutableMapping[CollectionT, List[EventT]]

    #: Cache of max buffer size by topic partition..
    buffer_sizes: MutableMapping[TP, int]

    #: Time in seconds after we warn that no flush has happened.
    flush_timeout_secs: float = 120.0

    #: Time in seconds after we warn that no events have been received.
    event_timeout_secs: float = 30.0

    #: Time of last event received by active TP
    _active_events_received_at: MutableMapping[TP, float]

    #: Time of last event received by standby TP
    _standby_events_received_at: MutableMapping[TP, float]

    #: Time of last event received (for any active TP)
    _last_active_event_processed_at: Optional[float]

    #: Time of last buffer flush
    _last_flush_at: Optional[float] = None

    #: Time when recovery last started
    _recovery_started_at: Optional[float] = None

    #: Time when recovery last ended
    _recovery_ended_at: Optional[float] = None

    _recovery_span: Optional[opentracing.Span] = None
    _actives_span: Optional[opentracing.Span] = None
    _standbys_span: Optional[opentracing.Span] = None

    #: List of last 100 processing timestamps (monotonic).
    #: Updated after processing every changelog record,
    #: used to estimate time remaining.
    _processing_times: Deque[float]

    #: Number of entries in _processing_times before
    #: we can give an estimate of time remaining.
    num_samples_required_for_estimate = 1000
    _generation_id: int = 0

    def __init__(self, app: AppT, tables: TableManagerT,
                 **kwargs: Any) -> None:
        """Set up empty recovery bookkeeping for *app* and *tables*.

        The recovery delay is taken from ``app.conf.stream_recovery_delay``;
        any remaining keyword arguments go to the base service constructor,
        which is called last.
        """
        self.app = app
        self.tables = cast(_TableManager, tables)

        # Assigned changelog partitions, split by role.
        self.active_tps = set()
        self.standby_tps = set()
        self.tp_to_table = {}

        # Offsets and highwater marks keyed by topic partition.
        self.active_offsets = Counter()
        self.standby_offsets = Counter()
        self.active_highwaters = Counter()
        self.standby_highwaters = Counter()

        self.completed = Event()

        self.buffers = defaultdict(list)
        self.buffer_sizes = {}
        self.recovery_delay = self.app.conf.stream_recovery_delay

        self.actives_for_table = defaultdict(set)
        self.standbys_for_table = defaultdict(set)

        # Per-partition timestamps of the last received event.
        self._active_events_received_at = {}
        self._standby_events_received_at = {}
        self._processing_times = deque()

        super().__init__(**kwargs)

    @property
    def signal_recovery_start(self) -> Event:
        """Event used to signal that recovery has started.

        Created lazily on first access, bound to ``self.loop``.
        """
        event = self._signal_recovery_start
        if event is None:
            event = Event(loop=self.loop)
            self._signal_recovery_start = event
        return event

    @property
    def signal_recovery_end(self) -> Event:
        """Event used to signal that recovery has ended.

        Created lazily on first access, bound to ``self.loop``.
        """
        event = self._signal_recovery_end
        if event is None:
            event = Event(loop=self.loop)
            self._signal_recovery_end = event
        return event

    async def on_stop(self) -> None:
        """Flush all buffered changelog events as the service stops."""
        flush = self.flush_buffers
        flush()

    def add_active(self, table: CollectionT, tp: TP) -> None:
        """Register *tp* as an active changelog partition of *table*."""
        self.actives_for_table[table].add(tp)
        self.active_tps.add(tp)
        self._add(table, tp, self.active_offsets)

    def add_standby(self, table: CollectionT, tp: TP) -> None:
        """Register *tp* as a standby changelog partition of *table*."""
        self.standbys_for_table[table].add(tp)
        self.standby_tps.add(tp)
        self._add(table, tp, self.standby_offsets)

    def _add(self, table: CollectionT, tp: TP, offsets: Counter[TP]) -> None:
        """Map *tp* to *table* and seed its entry in *offsets*.

        The table's persisted offset is used when available. Otherwise the
        entry is set to None, but only if *tp* has no entry yet.
        """
        self.tp_to_table[tp] = table
        persisted = table.persisted_offset(tp)
        if persisted is None:
            offsets.setdefault(tp, None)  # type: ignore
        else:
            offsets[tp] = persisted

    def revoke(self, tp: TP) -> None:
        """Forget offsets and highwaters tracked for the revoked *tp*."""
        for per_tp in (self.standby_offsets,
                       self.standby_highwaters,
                       self.active_offsets,
                       self.active_highwaters):
            per_tp.pop(tp, None)

    def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Flush changelog buffers, traced under the parent span."""
        traced_flush = traced_from_parent_span()(self.flush_buffers)
        traced_flush()

    async def on_rebalance(
        self,
        assigned: Set[TP],
        revoked: Set[TP],
        newly_assigned: Set[TP],
        generation_id: int = 0,
    ) -> None:
        """Call when cluster is rebalancing.

        Rebuilds the active/standby partition bookkeeping from the
        assignor's current assignment and then sets
        ``signal_recovery_start`` to wake the recovery task.

        If *generation_id* does not match ``app.consumer_generation_id``
        (a newer rebalance has started), nothing is changed.
        ``assigned`` and ``newly_assigned`` are not read here; the
        assignment comes from ``app.assignor``.
        """
        # removing all the sleeps so control does not go back to the loop
        app = self.app
        logger.info(f"generation id {generation_id} "
                    f"app consumers id {app.consumer_generation_id}")

        if generation_id != app.consumer_generation_id:
            # Stale rebalance callback: leave state for the newer one.
            logger.warning(
                f"rebalance again generation id "
                f"{generation_id} app consumers id {app.consumer_generation_id}"
            )
            return
        assigned_standbys = app.assignor.assigned_standbys()
        assigned_actives = app.assignor.assigned_actives()

        for tp in revoked:
            self.revoke(tp)

        # Rebuild partition sets from scratch for the new assignment.
        self.standby_tps.clear()
        self.active_tps.clear()
        self.actives_for_table.clear()
        self.standbys_for_table.clear()

        # Only partitions of topics that are table changelogs are tracked.
        for tp in assigned_standbys:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_standby(table, tp)
        for tp in assigned_actives:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_active(table, tp)

        # Drop active offsets for partitions no longer actively assigned.
        active_offsets = {
            tp: offset
            for tp, offset in self.active_offsets.items()
            if tp in self.active_tps
        }
        self.active_offsets.clear()
        self.active_offsets.update(active_offsets)

        rebalancing_span = cast(_App, self.app)._rebalancing_span
        if app.tracer and rebalancing_span:
            self._recovery_span = app.tracer.get_tracer("_faust").start_span(
                "recovery",
                child_of=rebalancing_span,
            )
            app._span_add_default_tags(self._recovery_span)
        # Wake the recovery task (it waits on this event) and record which
        # generation it should recover for.
        self.signal_recovery_start.set()
        self._generation_id = generation_id

    async def _resume_streams(self, generation_id: int = 0) -> None:
        """Resume normal stream processing after recovery.

        Notifies ``on_rebalance_complete`` listeners first. If a newer
        rebalance started meanwhile (generation mismatch), it logs and
        returns without resuming. Otherwise it resumes non-changelog
        partitions, seeks them to committed offsets, sets ``completed``,
        resumes flow control and the fetcher, and notifies the tables and
        app that the rebalance has ended.
        """
        app = self.app
        consumer = app.consumer
        await app.on_rebalance_complete.send()
        assignment = consumer.assignment()
        if self.app.consumer_generation_id != generation_id:
            # A newer rebalance is in progress; it will resume streams.
            self.log.warning("Recovery rebalancing again")
            return
        if assignment:
            self.log.dev("Resume stream partitions")
            # Changelog partitions are excluded; only stream partitions
            # are resumed here.
            consumer.resume_partitions(
                {tp
                 for tp in assignment if not self._is_changelog_tp(tp)})
            self.log.info("Seek stream partitions to committed offsets.")
            await self._wait(consumer.perform_seek(),
                             timeout=self.app.conf.broker_request_timeout)
        else:
            self.log.info("Resuming streams with empty assignment")
        self.completed.set()
        # Resume partitions and start fetching.
        self.log.info("Resuming flow...")
        app.flow_control.resume()
        consumer.resume_flow()
        # finally make sure the fetcher is running.
        await cast(_App, app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        self.tables.on_standbys_ready()
        app.on_rebalance_end()
        self.log.info("Worker ready")

    @Service.task
    async def _restart_recovery(self) -> None:
        """Background loop that runs table recovery after each rebalance.

        Each iteration waits for ``signal_recovery_start`` and then:

        * flushes table buffers and the producer,
        * builds highwaters/offsets for the active and standby changelog
          partitions and seeks the active ones,
        * if actives are behind, resumes the changelog partitions and waits
          for ``signal_recovery_end`` (set by ``_slurp_changelogs``),
        * starts the standby partitions, and
        * calls ``on_recovery_completed`` to resume stream processing.

        ``RebalanceAgain``/``IllegalStateError`` restart the loop, and
        ``ServiceStopped`` ends it.  Any other exception finishes the
        tracing spans and propagates.
        """
        consumer = self.app.consumer
        # Bound once.  This assumes these containers are mutated in place
        # rather than rebound (as is done for active_offsets on assignment).
        active_tps = self.active_tps
        standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_highwaters = self.active_highwaters
        standby_highwaters = self.standby_highwaters
        while not self.should_stop:
            self.log.dev("WAITING FOR NEXT RECOVERY TO START")
            if await self.wait_for_stopped(self.signal_recovery_start):
                self.signal_recovery_start.clear()
                break  # service was stopped
            self.signal_recovery_start.clear()
            generation_id = self._generation_id
            span: Any = None
            spans: list = []
            tracer: Optional[opentracing.Tracer] = None
            if self.app.tracer:
                tracer = self.app.tracer.get_tracer("_faust")
            if tracer is not None and self._recovery_span:
                span = tracer.start_span("recovery-thread",
                                         child_of=self._recovery_span)
                self.app._span_add_default_tags(span)
                spans.extend([span, self._recovery_span])
            T = traced_from_parent_span(span)

            try:
                await self._wait(T(asyncio.sleep)(self.recovery_delay))

                if not self.tables or self.app.conf.store == URL("aerospike:"):
                    # If there are no tables -- simply resume streams
                    await T(self._resume_streams)(generation_id=generation_id)
                    for _span in spans:
                        finish_span(_span)
                    continue

                self._set_recovery_started()
                self.standbys_pending = True
                # Must flush any buffers before starting rebalance.
                T(self.flush_buffers)()
                producer = cast(_App, self.app)._producer
                if producer is not None:
                    await self._wait(
                        T(producer.flush)(),
                        timeout=self.app.conf.broker_request_timeout,
                    )

                self.log.dev("Build highwaters for active partitions")
                await self._wait(
                    T(self._build_highwaters)(consumer, active_tps,
                                              active_highwaters, "active"),
                    timeout=self.app.conf.broker_request_timeout,
                )

                self.log.dev("Build offsets for active partitions")
                await self._wait(
                    T(self._build_offsets)(consumer, active_tps,
                                           active_offsets, "active"),
                    timeout=self.app.conf.broker_request_timeout,
                )
                if self.app.conf.recovery_consistency_check:
                    for tp in active_tps:
                        if (active_offsets[tp] and active_highwaters[tp] and
                                active_offsets[tp] > active_highwaters[tp]):
                            raise ConsistencyError(
                                E_PERSISTED_OFFSET.format(
                                    tp,
                                    active_offsets[tp],
                                    active_highwaters[tp],
                                ), )

                self.log.dev("Build offsets for standby partitions")
                await self._wait(
                    T(self._build_offsets)(consumer, standby_tps,
                                           standby_offsets, "standby"),
                    timeout=self.app.conf.broker_request_timeout,
                )

                self.log.dev("Seek offsets for active partitions")
                await self._wait(
                    T(self._seek_offsets)(consumer, active_tps,
                                          active_offsets, "active"),
                    timeout=self.app.conf.broker_request_timeout,
                )
                if self.signal_recovery_start.is_set():
                    logger.info("Restarting Recovery")
                    # ``continue`` skips the ``else`` clause below, so close
                    # this attempt's spans here or they are never finished.
                    for _span in spans:
                        finish_span(_span)
                    continue

                if self.need_recovery():
                    self._set_recovery_started()
                    self.standbys_pending = True
                    self.log.info("Restoring state from changelog topics...")
                    T(consumer.resume_partitions)(active_tps)
                    # Resume partitions and start fetching.
                    self.log.info("Resuming flow...")
                    T(self.app.flow_control.resume)()
                    T(consumer.resume_flow)()
                    await T(cast(_App, self.app)._fetcher.maybe_start)()

                    # Wait for actives to be up to date.
                    # This signal will be set by _slurp_changelogs
                    if tracer is not None and span:
                        self._actives_span = tracer.start_span(
                            "recovery-actives",
                            child_of=span,
                            tags={"Active-Stats": self.active_stats()},
                        )
                        # Tag the new child span (the parent was already
                        # tagged when it was created).
                        self.app._span_add_default_tags(self._actives_span)
                    try:
                        await self._wait(self.signal_recovery_end.wait())
                    except Exception as exc:
                        finish_span(self._actives_span, error=exc)
                        # Propagate so the outer handlers restart
                        # (RebalanceAgain) or stop (ServiceStopped) instead
                        # of continuing recovery with a stale assignment.
                        raise
                    else:
                        finish_span(self._actives_span)
                    finally:
                        self._actives_span = None

                    # recovery done.
                    self.log.info("Done reading from changelog topics")
                    T(consumer.pause_partitions)(active_tps)
                else:
                    self.log.info("Resuming flow...")
                    T(self.app.flow_control.resume)()
                    T(consumer.resume_flow)()
                    self._set_recovery_ended()
                self.log.info("Recovery complete")
                if span:
                    span.set_tag("Recovery-Completed", True)

                if standby_tps:
                    self.log.info("Starting standby partitions...")

                    self.log.dev("Seek standby offsets")
                    await self._wait(
                        T(self._seek_offsets)(consumer, standby_tps,
                                              standby_offsets, "standby"),
                        timeout=self.app.conf.broker_request_timeout,
                    )

                    self.log.dev("Build standby highwaters")
                    await self._wait(
                        T(self._build_highwaters)(
                            consumer,
                            standby_tps,
                            standby_highwaters,
                            "standby",
                        ),
                        timeout=self.app.conf.broker_request_timeout,
                    )
                    if self.app.conf.recovery_consistency_check:
                        for tp in standby_tps:
                            if (standby_offsets[tp] and standby_highwaters[tp]
                                    and standby_offsets[tp] >
                                    standby_highwaters[tp]):
                                raise ConsistencyError(
                                    E_PERSISTED_OFFSET.format(
                                        tp,
                                        standby_offsets[tp],
                                        standby_highwaters[tp],
                                    ), )

                    if tracer is not None and span:
                        self._standbys_span = tracer.start_span(
                            "recovery-standbys",
                            child_of=span,
                            tags={"Standby-Stats": self.standby_stats()},
                        )
                        # Tag the new child span, not the parent again.
                        self.app._span_add_default_tags(self._standbys_span)
                    self.log.dev("Resume standby partitions")
                    T(consumer.resume_partitions)(standby_tps)
                    T(self.app.flow_control.resume)()
                    T(consumer.resume_flow)()

                # Pause all our topic partitions,
                # to make sure we don't fetch any more records from them.
                await self._wait(T(self.on_recovery_completed)(generation_id))
            except RebalanceAgain as exc:
                self.log.dev("RAISED REBALANCE AGAIN")
                for _span in spans:
                    finish_span(_span, error=exc)
                continue  # another rebalance started
            except IllegalStateError as exc:
                self.log.dev("RAISED REBALANCE AGAIN")
                for _span in spans:
                    finish_span(_span, error=exc)
                continue  # another rebalance started
            except ServiceStopped as exc:
                self.log.dev("RAISED SERVICE STOPPED")
                for _span in spans:
                    finish_span(_span, error=exc)
                break  # service was stopped
            except Exception as exc:
                for _span in spans:
                    finish_span(_span, error=exc)
                raise
            else:
                for _span in spans:
                    finish_span(_span)
            # restart - wait for next rebalance.

    def _set_recovery_started(self) -> None:
        self.in_recovery = True
        self.app.in_recovery = True
        self._recovery_ended = None
        self._recovery_started_at = monotonic()
        self._active_events_received_at.clear()
        self._standby_events_received_at.clear()
        self._processing_times.clear()
        self._last_active_event_processed_at = None

    def _set_recovery_ended(self) -> None:
        self.in_recovery = False
        self.app.in_recovery = False
        self._recovery_ended_at = monotonic()
        self._active_events_received_at.clear()
        self._standby_events_received_at.clear()
        self._processing_times.clear()
        self._last_active_event_processed_at = None

    def active_remaining_seconds(self, remaining: float) -> str:
        """Return a human-readable estimate of time left for active recovery.

        Returns ``"???"`` when no estimate is available yet or when the
        estimate is zero.
        """
        estimate = self._estimated_active_remaining_secs(remaining)
        if not estimate:
            return "???"
        return humanize_seconds(estimate, now="none")

    def _estimated_active_remaining_secs(self,
                                         remaining: float) -> Optional[float]:
        processing_times = self._processing_times
        if len(processing_times) >= self.num_samples_required_for_estimate:
            mean_time = statistics.mean(processing_times)
            return (mean_time * remaining) * 1.10  # add 10%
        else:
            return None

    async def _wait(self,
                    coro: WaitArgT,
                    timeout: Optional[int] = None) -> None:
        """Await *coro* while watching for shutdown and new rebalances.

        Raises:
            ServiceStopped: the service stopped while waiting.
            RebalanceAgain: ``signal_recovery_start`` fired while waiting.
        """
        outcome = await self.wait_first(
            coro, self.signal_recovery_start, timeout=timeout)
        if outcome.stopped:
            raise ServiceStopped()
        if self.signal_recovery_start in outcome.done:
            # A new rebalance began before *coro* finished.
            raise RebalanceAgain()
        return None

    async def on_recovery_completed(self, generation_id: int = 0) -> None:
        """Call when active table recovery is completed.

        Ends recovery and runs every table's ``on_recovery_completed``
        callback.  Unless the consumer generation changed in the meantime,
        it then resumes and seeks the non-changelog (stream) partitions,
        restarts the flow of messages, and marks the worker ready.
        """
        consumer = self.app.consumer
        self.log.info("Restore complete!")
        await self.app.on_rebalance_complete.send()
        self._set_recovery_ended()
        # This needs to happen if all goes well
        callback_tasks = [
            asyncio.ensure_future(
                table.on_recovery_completed(
                    self.actives_for_table[table],
                    self.standbys_for_table[table],
                ))
            for table in self.tables.values()
        ]
        if callback_tasks:
            # asyncio.wait() no longer accepts bare coroutines (deprecated in
            # 3.8, TypeError since 3.11), so wrap them in tasks first.  As
            # before, a failing callback does not raise here.
            await asyncio.wait(callback_tasks)
        assignment = consumer.assignment()
        if self.app.consumer_generation_id != generation_id:
            self.log.warning(
                f"Recovery rebalancing again app id "
                f"{self.app.consumer_generation_id} param {generation_id}")
            return
        consumer.resume_partitions(
            {tp
             for tp in assignment if not self._is_changelog_tp(tp)})
        if assignment:
            self.log.info("Seek stream partitions to committed offsets.")
            await self._wait(consumer.perform_seek(),
                             timeout=self.app.conf.broker_request_timeout)
        self.completed.set()
        self.log.dev("Resume stream partitions")
        self.app.flow_control.resume()
        consumer.resume_flow()
        # finally make sure the fetcher is running.
        await cast(_App, self.app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        if not self.app.assignor.assigned_standbys():
            self.tables.on_standbys_ready()
        self.app.on_rebalance_end()
        self.log.info("Worker ready")

    async def _build_highwaters(self, consumer: ConsumerT, tps: Set[TP],
                                destination: Counter[TP], title: str) -> None:
        """Fetch highwater marks for *tps* and store them in *destination*.

        Each value is stored minus one (``-1`` when the consumer reports
        ``None``) and logged as a table.  *destination* is cleared first.
        """
        reported = await consumer.highwaters(*tps)
        adjusted = {}
        for tp, value in reported.items():
            # FIXME the -1 here is because of the way we commit offsets
            adjusted[tp] = -1 if value is None else value - 1
        self.log.info(
            "Highwater for %s changelog partitions:\n%s",
            title,
            self._highwater_logtable(adjusted, title=title),
        )
        destination.clear()
        destination.update(adjusted)

    def _highwater_logtable(self, highwaters: Mapping[TP, int], *,
                            title: str) -> str:
        """Render *highwaters* as a terminal table sorted by partition."""
        rows = []
        for tp, highwater in sorted(highwaters.items()):
            rows.append([tp.topic, str(tp.partition), str(highwater)])
        return terminal.logtable(
            list(self._consolidate_table_keys(rows)),
            title=f"Highwater - {title.capitalize()}",
            headers=["topic", "partition", "highwater"],
        )

    def _consolidate_table_keys(self, data: TableDataT) -> Iterator[List[str]]:
        """Format terminal log table to reduce noise from duplicate keys.

        The first column of each logged row holds the topic name, which
        gets noisy when it repeats row after row.  Every row whose key
        equals the previous row's key has that key replaced with the ditto
        mark.

        Note:
            Data must be sorted.
        """
        previous_key: Optional[str] = None
        for key, *rest in data:
            repeated = previous_key is not None and key == previous_key
            yield ["〃" if repeated else key, *rest]  # ditto
            previous_key = key

    async def _build_offsets(self, consumer: ConsumerT, tps: Set[TP],
                             destination: Counter[TP], title: str) -> None:
        """Merge the earliest available offsets for *tps* into *destination*.

        For every partition the stored offset becomes the larger of the
        value already in *destination* and the consumer's earliest offset
        minus one.  Records may have been compacted away, so reading has to
        start at the recent ones.  If either side is ``None``, the other one
        is kept.  The result is logged as a table.
        """
        # -- Update offsets
        # Offsets may have been compacted, need to get to the recent ones
        earliest = await consumer.earliest_offsets(*tps)
        # FIXME To be consistent with the offset -1 logic
        # Keep None as None (as _build_highwaters does) instead of raising
        # TypeError on ``None - 1``; the merge below already handles None.
        earliest = {
            tp: offset - 1 if offset is not None else None
            for tp, offset in earliest.items()
        }
        for tp in tps:
            last_value = destination[tp]
            # A partition missing from the response is treated like None.
            new_value = earliest.get(tp)

            if last_value is None:
                destination[tp] = new_value
            elif new_value is None:
                destination[tp] = last_value
            else:
                destination[tp] = max(last_value, new_value)

        if destination:
            self.log.info(
                "%s offsets at start of reading:\n%s",
                title,
                self._start_offsets_logtable(destination, title=title),
            )

    def _start_offsets_logtable(self, offsets: Mapping[TP, int], *,
                                title: str) -> str:
        """Render the starting *offsets* as a terminal table."""
        rows = [
            [tp.topic, str(tp.partition), str(offset)]
            for tp, offset in sorted(offsets.items())
        ]
        consolidated = list(self._consolidate_table_keys(rows))
        return terminal.logtable(
            consolidated,
            title=f"Reading Starts At - {title.capitalize()}",
            headers=["topic", "partition", "offset"],
        )

    async def _seek_offsets(self, consumer: ConsumerT, tps: Set[TP],
                            offsets: Counter[TP], title: str) -> None:
        """Seek *consumer* to the stored offset of every partition in *tps*.

        A stored offset of ``-1`` is sent as ``0``.  *title* is not used.
        """
        # FIXME Remove check when fixed offset-1 discrepancy
        targets = {}
        for tp in tps:
            stored = offsets[tp]
            targets[tp] = 0 if stored == -1 else stored
        await consumer.seek_wait(targets)

    @Service.task
    async def _slurp_changelogs(self) -> None:
        """Consume changelog events and apply them to their tables.

        Each event from ``tables.changelog_queue`` is routed to its table
        (active or standby partition), buffered, and applied in batches once
        the per-partition buffer size is reached.  The highest offset seen
        per partition is recorded.  ``signal_recovery_end`` is set once no
        active recovery remains, and the tables' ``on_standbys_ready`` is
        called whenever no standby changes remain.  Events for partitions
        that are neither active nor standby are logged and skipped.
        """
        changelog_queue = self.tables.changelog_queue
        tp_to_table = self.tp_to_table

        active_tps = self.active_tps
        standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_events_received_at = self._active_events_received_at
        standby_events_received_at = self._standby_events_received_at

        buffers = self.buffers
        buffer_sizes = self.buffer_sizes
        processing_times = self._processing_times

        async def _maybe_signal_recovery_end(timeout=False,
                                             timeout_count=0) -> None:
            # lets wait at least 2 consecutive cycles for the queue to be
            # empty to avoid race conditions between
            # the aiokafka consumer position and draining of the queue
            if timeout and self.app.in_transaction and timeout_count > 1:
                await detect_aborted_tx()
            if not self.need_recovery() and self.in_recovery:
                # apply anything stuck in the buffers
                self.flush_buffers()
                self._set_recovery_ended()
                if self._actives_span is not None:
                    self._actives_span.set_tag("Actives-Ready", True)
                logger.debug("Setting recovery end")
                self.signal_recovery_end.set()

        async def detect_aborted_tx():
            # If the consumer position already reached the highwater while
            # the recorded offset lags behind, treat the partition as caught
            # up.  Presumably the gap consists of aborted transactional
            # records that are never delivered -- confirm.
            highwaters = self.active_highwaters
            offsets = self.active_offsets
            for tp, highwater in highwaters.items():
                if (highwater is not None and offsets[tp] is not None
                        and offsets[tp] < highwater):
                    if await self.app.consumer.position(tp) >= highwater:
                        logger.info(f"Aborted tx until highwater for {tp}")
                        offsets[tp] = highwater

        timeout_count = 0
        while not self.should_stop:
            try:
                self.signal_recovery_end.clear()
                try:
                    event: EventT = await asyncio.wait_for(
                        changelog_queue.get(), timeout=5.0)
                except asyncio.TimeoutError:
                    timeout_count += 1
                    if self.should_stop:
                        return
                    await _maybe_signal_recovery_end(
                        timeout=True, timeout_count=timeout_count)
                    continue
                now = monotonic()
                timeout_count = 0
                message = event.message
                tp = message.tp
                offset = message.offset
                # Lazy %-formatting: this runs for every changelog record.
                logger.debug("Recovery message topic %s offset %s", tp, offset)
                offsets: Counter[TP]
                bufsize = buffer_sizes.get(tp)
                is_active = False
                if tp in active_tps:
                    is_active = True
                    table = tp_to_table[tp]
                    offsets = active_offsets
                    if bufsize is None:
                        bufsize = buffer_sizes[tp] = table.recovery_buffer_size
                    active_events_received_at[tp] = now
                elif tp in standby_tps:
                    table = tp_to_table[tp]
                    offsets = standby_offsets
                    if bufsize is None:
                        bufsize = buffer_sizes[tp] = table.standby_buffer_size
                    # Record every standby event, like the active branch
                    # (previously only the first event was recorded).
                    standby_events_received_at[tp] = now
                else:
                    logger.warning(
                        f"recovery unknown topic {tp} offset {offset}")
                    # Skip it: falling through would reuse ``table`` and
                    # ``offsets`` from a previous event (or hit an unbound
                    # local on the first one) and apply this record to the
                    # wrong table.
                    continue

                seen_offset = offsets.get(tp, None)
                logger.debug(
                    "seen offset for %s is %s message offset %s",
                    tp, seen_offset, offset,
                )
                if seen_offset is None or offset > seen_offset:
                    offsets[tp] = offset
                    buf = buffers[table]
                    buf.append(event)
                    await table.on_changelog_event(event)
                    if len(buf) >= bufsize:
                        table.apply_changelog_batch(buf)
                        buf.clear()
                        self._last_flush_at = now
                    now_after = monotonic()

                    if is_active:
                        last_processed_at = self._last_active_event_processed_at
                        if last_processed_at is not None:
                            processing_times.append(now_after -
                                                    last_processed_at)
                            max_samples = self.num_samples_required_for_estimate
                            if len(processing_times) > max_samples:
                                processing_times.popleft()
                        self._last_active_event_processed_at = now_after

                await _maybe_signal_recovery_end()

                if not self.standby_remaining_total():
                    logger.debug("Completed standby partition fetch")
                    if self._standbys_span:
                        finish_span(self._standbys_span)
                        self._standbys_span = None
                    self.tables.on_standbys_ready()
            except Exception as ex:
                # Best effort: keep consuming, but keep the traceback too.
                logger.warning(f"Error in recovery {ex}", exc_info=True)

    def flush_buffers(self) -> None:
        """Flush changelog buffers.

        Applies each table's buffered events as one batch, empties that
        buffer, and records the flush time.
        """
        buffered = self.buffers
        for table in buffered:
            batch = buffered[table]
            table.apply_changelog_batch(batch)
            batch.clear()
        self._last_flush_at = monotonic()

    def need_recovery(self) -> bool:
        """Return :const:`True` if recovery is required.

        Recovery is required while any active partition has a positive
        number of remaining changes.
        """
        for count in self.active_remaining().values():
            if count > 0:
                return True
        return False

    def active_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by active partition.

        Partitions whose highwater or offset is ``None`` are left out.
        """
        offsets = self.active_offsets
        remaining = {}
        for tp, highwater in self.active_highwaters.items():
            if highwater is None or offsets[tp] is None:
                continue
            remaining[tp] = highwater - offsets[tp]
        return Counter(remaining)

    def standby_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by standby partition.

        Only partitions whose highwater and offset are both non-negative
        are included.
        """
        offsets = self.standby_offsets
        remaining = {}
        for tp, highwater in self.standby_highwaters.items():
            if highwater < 0 or offsets[tp] < 0:
                continue
            remaining[tp] = highwater - offsets[tp]
        return Counter(remaining)

    def active_remaining_total(self) -> int:
        """Return number of changes remaining for actives to be up-to-date."""
        total = 0
        for count in self.active_remaining().values():
            total += count
        return total

    def standby_remaining_total(self) -> int:
        """Return number of changes remaining for standbys to be up-to-date."""
        total = 0
        for count in self.standby_remaining().values():
            total += count
        return total

    def active_stats(self) -> RecoveryStatsMapping:
        """Return current active recovery statistics.

        Maps every active partition whose offset is known and differs from
        its highwater to ``RecoveryStats(highwater, offset, difference)``.
        """
        offsets = self.active_offsets
        stats = {}
        for tp, highwater in self.active_highwaters.items():
            offset = offsets[tp]
            if offset is None:
                continue
            delta = highwater - offset
            if delta != 0:
                stats[tp] = RecoveryStats(highwater, offset, delta)
        return stats

    def standby_stats(self) -> RecoveryStatsMapping:
        """Return current standby recovery statistics.

        Maps every standby partition whose offset is known and differs from
        its highwater to ``RecoveryStats(highwater, offset, difference)``.
        """
        offsets = self.standby_offsets
        stats = {}
        for tp, highwater in self.standby_highwaters.items():
            offset = offsets[tp]
            if offset is None:
                continue
            delta = highwater - offset
            if delta != 0:
                stats[tp] = RecoveryStats(highwater, offset, delta)
        return stats

    def _stats_to_logtable(self, title: str,
                           stats: RecoveryStatsMapping) -> str:
        """Render recovery *stats* as a terminal table sorted by partition."""
        rows = [
            [str(field) for field in (tp.topic, tp.partition,
                                      s.highwater, s.offset, s.remaining)]
            for tp, s in sorted(stats.items())
        ]
        headers = [
            "topic",
            "partition",
            "need offset",
            "have offset",
            "remaining",
        ]
        return terminal.logtable(
            list(self._consolidate_table_keys(rows)),
            title=title,
            headers=headers,
        )

    @Service.task
    async def _publish_stats(self) -> None:
        """Emit stats (remaining to fetch) while in active recovery.

        Every ``stats_interval`` seconds while in recovery:

        * with enough samples, log the remaining work and an ETA;
        * with remaining work but too few samples, run the stall checks;
        * with nothing remaining, log an error if recovery has lasted 30s
          or more (or its start time is missing).
        """
        interval = self.stats_interval
        await self.sleep(interval)
        async for _ in self.itertimer(interval, name="Recovery.stats"):
            if not self.in_recovery:
                continue
            now = monotonic()
            stats = self.active_stats()
            sample_count = len(self._processing_times)
            if stats and sample_count >= self.num_samples_required_for_estimate:
                remaining_total = self.active_remaining_total()
                self.log.info(
                    "Still fetching changelog topics for recovery, "
                    "estimated time remaining %s "
                    "(total remaining=%r):\n%s",
                    self.active_remaining_seconds(remaining_total),
                    remaining_total,
                    self._stats_to_logtable(
                        "Remaining for active recovery", stats),
                )
                continue
            if stats:
                await self._verify_remaining(now, stats)
                continue
            started_at = self._recovery_started_at
            if started_at is None:
                self.log.error(
                    "POSSIBLE INTERNAL ERROR: "
                    "Recovery marked as started but missing "
                    "self._recovery_started_at timestamp.")
                continue
            elapsed = now - started_at
            if elapsed >= 30.0:
                # This shouldn't happen, but we want to
                # log an error in case it does.
                self.log.error(
                    "POSSIBLE INTERNAL ERROR: "
                    "Recovery has no remaining offsets to fetch, "
                    "but we have spent %s waiting for the worker "
                    "to transition out of recovery state...",
                    humanize_seconds(elapsed),
                )

    async def _verify_remaining(self, now: float,
                                stats: RecoveryStatsMapping) -> None:
        """Log warnings when active recovery appears stalled.

        Warns if buffers have not been flushed within ``flush_timeout_secs``
        (measured from recovery start when there has been no flush yet,
        otherwise from the last flush).  For each partition in *stats* it
        calls ``consumer.verify_recovery_event_path`` and warns if no event
        arrived within ``event_timeout_secs``.  The per-partition loop yields
        to the event loop and stops early when the service is stopping or
        recovery has ended.
        """
        consumer = self.app.consumer
        active_events_received_at = self._active_events_received_at
        recovery_started_at = self._recovery_started_at
        if recovery_started_at is None:
            return  # we already log about this in _publish_stats
        # ``now`` is fixed for the whole call, so this is loop-invariant.
        secs_since_started = now - recovery_started_at

        last_flush_at = self._last_flush_at
        if last_flush_at is None:
            if secs_since_started >= self.flush_timeout_secs:
                self.log.warning(
                    "Recovery has not flushed buffers since "
                    "recovery started (started %s). "
                    "Current total buffer size: %r",
                    humanize_seconds_ago(secs_since_started),
                    self._current_total_buffer_size(),
                )
        else:
            secs_since_last_flush = now - last_flush_at
            if secs_since_last_flush >= self.flush_timeout_secs:
                self.log.warning(
                    "Recovery has not flushed buffers in the last %r "
                    "seconds (last flush was %s). "
                    "Current total buffer size: %r",
                    self.flush_timeout_secs,
                    humanize_seconds_ago(secs_since_last_flush),
                    self._current_total_buffer_size(),
                )

        for tp in stats:
            await self.sleep(0)
            if self.should_stop:
                break
            if not self.in_recovery:
                break
            consumer.verify_recovery_event_path(now, tp)

            last_event_received = active_events_received_at.get(tp)
            if last_event_received is None:
                if secs_since_started >= self.event_timeout_secs:
                    self.log.warning(
                        "No event received for active tp %r since recovery "
                        "start (started %s)",
                        tp,
                        humanize_seconds_ago(secs_since_started),
                    )
                continue

            secs_since_received = now - last_event_received
            if secs_since_received >= self.event_timeout_secs:
                self.log.warning(
                    "No event received for active tp %r in the last %r "
                    "seconds (last event received %s)",
                    tp,
                    self.event_timeout_secs,
                    humanize_seconds_ago(secs_since_received),
                )

    def _current_total_buffer_size(self) -> int:
        return sum(len(buf) for buf in self.buffers.values())

    def _is_changelog_tp(self, tp: TP) -> bool:
        """Return True if *tp*'s topic is one of the tables' changelog topics."""
        changelog_topics = self.tables.changelog_topics
        return tp.topic in changelog_topics
class SimulationServer():
    @inject
    def __init__(self, port, time: Time, session: ClientSession,
                 authserver: Server, config: Config):
        """Create an in-memory fake of Google Drive, the Supervisor and HA.

        All simulated state lives on this instance. The request handlers
        read and mutate it, so tests can inspect it or inject failures.

        Args:
            port: Port this server is reachable on. It is used to build the
                resumable-upload ``Location`` URL, the accepted OAuth
                redirect URI and is echoed nowhere else in the visible code.
            time: Clock object, stored as ``self._time`` (typed ``FakeTime``).
                ``now()`` and ``sleepAsync()`` are called on it.
            session: Accepted but neither stored nor used in the visible code.
                Presumably it is required only by the injector; confirm.
            authserver: Stored as ``self._authserver``. It is not used in the
                visible code.
            config: Add-on configuration. It supplies the default Drive client
                id/secret and the ingress port.
        """
        # Simulated Drive items, keyed by item id.
        self.items: Dict[str, Any] = {}
        self.config = config
        # NOTE(review): not referenced in the visible code.
        self.id_counter = 0
        # State of the single in-flight resumable upload
        # (written by driveStartUpload, consumed by driveContinueUpload).
        self.upload_info: Dict[str, Any] = {}
        # When True, driveError() alternates between error_code and None.
        self.simulate_drive_errors = False
        # When True, driveStartUpload answers with storageQuotaExceeded.
        self.simulate_out_of_drive_space = False
        self.error_code = 500
        # Rules added by setError(): {'url': regex, 'attempts': n, 'status': s}.
        self.match_errors = []
        # Toggle used by driveError() to alternate failures.
        self.last_error = False
        # Snapshot slug -> parsed snapshot info / raw snapshot bytes.
        self.snapshots: Dict[str, Any] = {}
        self.snapshot_data: Dict[str, bytearray] = {}
        # Arbitrary blobs stored by uploadfile() and served by readFile().
        self.files: Dict[str, bytearray] = {}
        # Sizes of every upload chunk accepted by driveContinueUpload.
        self.chunks = []
        self.settings: Dict[str, Any] = self.defaultSettings()
        # Extra client id accepted (with the default secret) by
        # checkClientIdandSecret when not None.
        self.client_id_hack = None
        # Serialises snapshot creation in the hassioNew*Snapshot handlers.
        self._snapshot_lock = asyncio.Lock()
        # Guards self.settings. It is used with a plain `with` statement, so
        # presumably it is a threading.Lock; confirm against the imports.
        self._settings_lock = Lock()
        self._port = port
        # HTTP status to fail Home Assistant API calls with (None = succeed).
        self._ha_error = None
        # Entity id -> state / attributes posted through haStateUpdate.
        self._entities = {}
        # (event name, payload) tuples posted through haEventUpdate.
        self._events = []
        self._attributes = {}
        # Last notification created and not yet dismissed.
        self._notification = None
        self._time: FakeTime = time
        self._options = self.defaultOptions()
        self._username = "******"
        self._password = "******"
        # Item ids for which Drive responds with 403 "forbidden".
        self.lostPermission = []
        # Every request URL seen by error_middleware.
        self.urls = []
        self.relative = True
        # Reject new snapshot requests while True (unless always_hard_lock).
        self.block_snapshots = False
        self.snapshot_in_progress = False
        # Fixed auth code handed out by driveAuthorize and checked by driveToken.
        self.drive_auth_code = "drive_auth_code"
        self._authserver = authserver
        # Used by driveContinueUpload. When the chunk counter reaches
        # waitOnChunk, the trigger is set and the handler awaits the wait event.
        self._upload_chunk_wait = Event()
        self._upload_chunk_trigger = Event()
        self.current_chunk = 1
        self.waitOnChunk = 0
        # Randomly generated "custom" Drive OAuth client credentials.
        self.custom_drive_client_id = self.generateId(5)
        self.custom_drive_client_secret = self.generateId(5)
        # HTTP status returned by every Supervisor handler when not None.
        self.supervisor_error = None
        # Artificial per-request delays (seconds) for Drive / Supervisor calls.
        self.drive_sleep = 0
        self.supervisor_sleep = 0

    def wasUrlRequested(self, pattern):
        for url in self.urls:
            if pattern in url:
                return True
        return False

    def blockSnapshots(self):
        self.block_snapshots = True

    def unBlockSnapshots(self):
        self.block_snapshots = False

    def setError(self, url_regx, attempts=0, status=500):
        self.match_errors.append({
            'url': url_regx,
            'attempts': attempts,
            'status': status
        })

    def defaultOptions(self):
        return {
            "max_snapshots_in_hassio": 4,
            "max_snapshots_in_google_drive": 4,
            "days_between_snapshots": 3,
            "use_ssl": False
        }

    def getEvents(self):
        return self._events.copy()

    def setHomeAssistantError(self, status_code):
        self._ha_error = status_code

    def getEntity(self, entity):
        return self._entities.get(entity)

    def clearEntities(self):
        self._entities = {}

    def getAttributes(self, attribute):
        return self._attributes.get(attribute)

    def getNotification(self):
        return self._notification

    def _reset(self) -> None:
        with self._settings_lock:
            self._ha_error = None
            self.items = {}
            self.upload_info = {}
            self.snapshots = {}
            self.snapshot_data = {}
            self.files = {}
            self._entities = {}
            self._attributes = {}
            self._notification = None
            self.settings = self.defaultSettings()
            self._options = self.defaultOptions()

    def getSetting(self, key):
        with self._settings_lock:
            return self.settings[key]

    def update(self, config):
        with self._settings_lock:
            self.settings.update(config)

    def defaultSettings(self):
        return {
            'snapshot_wait_time': 0,
            'snapshot_min_size': 1024 * 256 * 1,
            'snapshot_max_size': 1024 * 256 * 2,
            'ha_header': "test_header",
            "ha_version": "0.91.3",
            "ha_last_version": "0.91.2",
            "machine": "raspberrypi3",
            "ip_address": "172.30.32.1",
            "arch": "armv7",
            "image": "homeassistant/raspberrypi3-homeassistant",
            "custom": True,
            "drive_upload_error": None,
            "drive_upload_error_attempts": 0,
            "boot": True,
            "port": 8099,
            "ha_port": 1337,
            "ssl": False,
            "watchdog": True,
            "wait_boot": 600,
            "web_ui": "http://[HOST]:8099/",
            "ingress_url": "/index",
            "supervisor": "2.2.2",
            "homeassistant": "0.93.1",
            "hassos": "0.69.69",
            "hassio_error": None,
            "hassio_snapshot_error": None,
            "hostname": "localhost",
            "always_hard_lock": False,
            "supported_arch": [],
            "channel": "dev",
            "addon_slug": "self_slug",
            "drive_refresh_token": "",
            "drive_auth_token": "",
            "drive_upload_sleep": 0,
            "drive_all_error": None
        }

    def driveError(self) -> Any:
        if not self.simulate_drive_errors:
            return False
        if not self.last_error:
            self.last_error = True
            return self.error_code
        else:
            self.last_error = False
            return None

    async def readAll(self, request):
        data = bytearray()
        content = request.content
        while True:
            chunk, done = await content.readchunk()
            data.extend(chunk)
            if len(chunk) == 0:
                break
        return data

    def _checkDriveError(self, request: Request):
        """Raise the configured simulated failure for a Drive request, if any.

        Checks, in order:
          * the global ``drive_all_error`` setting;
          * the alternating error produced by driveError();
          * URL rules registered with setError(). A matching rule lets its
            remaining ``attempts`` through (decrementing the count) and
            fails afterwards.

        Raises:
            HttpMultiException: carrying the status to respond with.
        """
        blanket_error = self.getSetting("drive_all_error")
        if blanket_error:
            raise HttpMultiException(blanket_error)
        alternating = self.driveError()
        if alternating:
            raise HttpMultiException(alternating)
        url = str(request.url)
        for rule in self.match_errors:
            if not re.match(rule['url'], url):
                continue
            if rule['attempts'] <= 0:
                raise HttpMultiException(rule['status'])
            rule['attempts'] = rule['attempts'] - 1

    async def _checkDriveHeaders(self, request: Request):
        """Apply simulated latency and errors, then check the Drive bearer token.

        Raises:
            HttpMultiException: when a simulated error fires.
            HTTPUnauthorized: when ``Authorization`` is not
                ``Bearer <drive_auth_token>``.
        """
        if self.drive_sleep > 0:
            await asyncio.sleep(self.drive_sleep)
        self._checkDriveError(request)
        presented = request.headers.get("Authorization", "")
        expected = "Bearer " + self.getSetting('drive_auth_token')
        if presented != expected:
            raise HTTPUnauthorized()

    async def driveRefreshToken(self, request: Request):
        """Simulate Google's OAuth token-refresh endpoint.

        The form body must contain valid client credentials, the current
        refresh token and ``grant_type=refresh_token``. Any mismatch raises
        HTTPUnauthorized. A missing field raises KeyError, as before. On
        success a new access token is minted and returned with a one-hour
        lifetime.
        """
        params = await request.post()
        if not self.checkClientIdandSecret(params['client_id'],
                                           params['client_secret']):
            raise HTTPUnauthorized()
        if params['refresh_token'] != self.getSetting('drive_refresh_token'):
            raise HTTPUnauthorized()
        if params['grant_type'] != 'refresh_token':
            raise HTTPUnauthorized()

        self.generateNewAccessToken()

        return json_response({
            # Read through getSetting() so the access is guarded by
            # _settings_lock, like every other settings read in this class.
            'access_token': self.getSetting('drive_auth_token'),
            'expires_in': 3600,
            'token_type': 'doesn\'t matter'
        })

    def generateNewAccessToken(self):
        new_token = self.generateId(20)
        with self._settings_lock:
            self.settings['drive_auth_token'] = new_token

    def generateNewRefreshToken(self):
        new_token = self.generateId(20)
        with self._settings_lock:
            self.settings['drive_refresh_token'] = new_token

    async def driveAuthorize(self, request: Request):
        """Simulate Google's OAuth consent endpoint.

        Every expected authorization-request parameter is validated. Any
        mismatch raises HTTPUnauthorized. On success the client is
        redirected (HTTPSeeOther) to ``redirect_uri`` with the fixed auth
        code and the echoed ``state``.
        """
        params = request.query
        accepted_ids = (self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID),
                        self.custom_drive_client_id)
        if params.get('client_id') not in accepted_ids:
            raise HTTPUnauthorized()
        required_values = {
            'scope': 'https://www.googleapis.com/auth/drive.file',
            'response_type': 'code',
            'include_granted_scopes': 'true',
            'access_type': 'offline',
        }
        for name, expected in required_values.items():
            if params.get(name) != expected:
                raise HTTPUnauthorized()
        for name in ('state', 'redirect_uri'):
            if name not in params:
                raise HTTPUnauthorized()
        if params.get('prompt') != 'consent':
            raise HTTPUnauthorized()
        target = URL(params.get('redirect_uri')).with_query({
            'code': self.drive_auth_code,
            'state': params.get('state'),
        })
        raise HTTPSeeOther(str(target))

    async def driveToken(self, request: Request):
        """Simulate Google's authorization-code exchange endpoint.

        The form body must carry an accepted redirect URI, the
        ``authorization_code`` grant type, valid client credentials and the
        code issued by driveAuthorize. Otherwise HTTPUnauthorized is raised.
        On success a new refresh token is minted and returned together with
        the current access token.
        """
        form = await request.post()
        accepted_redirects = [
            "http://localhost:{}/drive/authorize".format(self._port),
            'urn:ietf:wg:oauth:2.0:oob'
        ]
        acceptable = (
            form.get('redirect_uri') in accepted_redirects
            and form.get('grant_type') == 'authorization_code'
            and self.checkClientIdandSecret(form.get('client_id'),
                                            form.get('client_secret'))
            and form.get('code') == self.drive_auth_code)
        if not acceptable:
            raise HTTPUnauthorized()
        self.generateNewRefreshToken()
        payload = {
            'access_token': self.getSetting('drive_auth_token'),
            'refresh_token': self.getSetting('drive_refresh_token'),
            'client_id': form.get('client_id'),
            'client_secret': self.config.get(
                Setting.DEFAULT_DRIVE_CLIENT_SECRET),
            'token_expiry': self.timeToRfc3339String(self._time.now()),
        }
        return json_response(payload)

    def checkClientIdandSecret(self, client_id: str,
                               client_secret: str) -> bool:
        """Return True if the (client_id, client_secret) pair is accepted.

        Three pairs are accepted:
          * the randomly generated custom client credentials;
          * the default client id/secret from the add-on config;
          * ``client_id_hack`` (when set) together with the default secret.
        """
        if (client_id == self.custom_drive_client_id
                and client_secret == self.custom_drive_client_secret):
            return True
        # The previous chained comparison (`client_id == X == client_id`)
        # repeated client_id and is reduced to the single intended check.
        if (client_id == self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID)
                and client_secret == self.config.get(
                    Setting.DEFAULT_DRIVE_CLIENT_SECRET)):
            return True
        if (self.client_id_hack is not None
                and client_id == self.client_id_hack
                and client_secret == self.config.get(
                    Setting.DEFAULT_DRIVE_CLIENT_SECRET)):
            return True
        return False

    def expireCreds(self):
        self.generateNewAccessToken()
        self.generateNewRefreshToken()

    def expireRefreshToken(self):
        self.generateNewRefreshToken()

    def resetDriveAuth(self):
        """Expire both tokens and give the default client a new random id/secret."""
        self.expireCreds()
        for setting in (Setting.DEFAULT_DRIVE_CLIENT_ID,
                        Setting.DEFAULT_DRIVE_CLIENT_SECRET):
            self.config.override(setting, self.generateId(5))

    def getCurrentCreds(self):
        """Build Creds for the current tokens, expiring one hour from now."""
        client_id = self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID)
        expires_at = self._time.now() + timedelta(hours=1)
        return Creds(self._time,
                     id=client_id,
                     expiration=expires_at,
                     access_token=self.getSetting("drive_auth_token"),
                     refresh_token=self.getSetting("drive_refresh_token"))

    async def reset(self, request: Request):
        """Reset all simulated state, then apply optional setting overrides.

        *request* may be an aiohttp Request, whose query parameters become
        overrides, or a plain dict of overrides. Anything else just resets.
        """
        self._reset()
        overrides = []
        if isinstance(request, Request):
            overrides.append(request.query)
        if isinstance(request, Dict):
            overrides.append(request)
        for settings in overrides:
            self.update(settings)

    async def uploadfile(self, request: Request):
        """Store the request body under ``?name=`` (default "test")."""
        key = str(request.query.get("name", "test"))
        body = await self.readAll(request)
        self.files[key] = body
        return Response(text="")

    async def readFile(self, request: Request):
        """Serve a blob stored by uploadfile (``?name=``, default "test")."""
        key = request.query.get("name", "test")
        return self.serve_bytes(request, self.files[key])

    def serve_bytes(self,
                    request: Request,
                    bytes: bytearray,
                    include_length: bool = True) -> Any:
        """Serve *bytes*, honouring a single ``Range: bytes=a-b`` header.

        Without a Range header the whole buffer is returned with a
        Content-length header. With one, the inclusive slice [start, end] is
        returned with status 206 and a ``Content-Range`` header. A malformed
        or out-of-bounds range raises HTTPBadRequest.

        Args:
            request: Incoming request; only its headers are read.
            bytes: Payload to serve. The name shadows the builtin ``bytes``
                inside this method.
            include_length: For range responses, whether to also set
                Content-length.
        """
        if "Range" in request.headers:
            # Do range request
            if not rangePattern.match(request.headers['Range']):
                raise HTTPBadRequest()

            # Presumably rangePattern guarantees exactly two integers here.
            numbers = intPattern.findall(request.headers['Range'])
            start = int(numbers[0])
            end = int(numbers[1])

            if start < 0:
                raise HTTPBadRequest()
            if start > end:
                raise HTTPBadRequest()
            if end > len(bytes) - 1:
                raise HTTPBadRequest()
            # HTTP ranges are inclusive of `end`, hence the +1.
            resp = Response(body=bytes[start:end + 1], status=206)
            resp.headers['Content-Range'] = "bytes {0}-{1}/{2}".format(
                start, end, len(bytes))
            if include_length:
                # NOTE(review): this advertises the FULL buffer length while
                # the body is only the requested slice (end - start + 1
                # bytes). A spec-compliant 206 would send the slice length.
                # Confirm whether this mismatch is intentional simulation of
                # the real Supervisor before relying on it.
                resp.headers["Content-length"] = str(len(bytes))
            return resp
        else:
            resp = Response(body=io.BytesIO(bytes))
            resp.headers["Content-length"] = str(len(bytes))
            return resp

    async def updateSettings(self, request: Request):
        """Overwrite settings from the JSON body, then from the query string.

        Query parameters win over body keys of the same name because they
        are applied last. All writes happen under the settings lock.
        """
        body = await request.json()
        with self._settings_lock:
            for source in (body, request.query):
                for key in source:
                    self.settings[key] = source[key]
        return Response(text="updated")

    async def driveGetItem(self, request: Request):
        """Simulate Drive ``files.get``.

        With ``alt=media`` the item's bytes are served (range requests
        supported). Otherwise the item's metadata is returned, trimmed to the
        comma-separated ``fields`` (default "id"). Unknown ids give 404,
        revoked ones 403.
        """
        item_id = request.match_info.get('id')
        await self._checkDriveHeaders(request)
        if item_id not in self.items:
            raise HTTPNotFound
        if item_id in self.lostPermission:
            return Response(
                status=403,
                content_type="application/json",
                text='{"error": {"errors": [{"reason": "forbidden"}]}}')
        item = self.items[item_id]
        if request.query.get("alt", "metadata") != "media":
            wanted = request.query.get("fields", "id").split(",")
            return json_response(self.filter_fields(item, wanted))
        if 'bytes' not in item:
            raise HTTPBadRequest
        return self.serve_bytes(request, item['bytes'], include_length=False)

    async def driveUpdate(self, request: Request):
        """Simulate a Drive ``files.update`` metadata PATCH.

        For each top-level key in the JSON body, an existing dict value is
        merged (one level deep) and any other value is overwritten.

        Raises:
            HTTPNotFound: when the item id is unknown.
        """
        item_id = request.match_info.get('id')
        await self._checkDriveHeaders(request)
        if item_id not in self.items:
            # Raise (as driveGetItem/driveDelete do). Returning the exception
            # class is not a valid handler response and never produced a 404.
            raise HTTPNotFound()
        update = await request.json()
        item = self.items[item_id]
        for key in update:
            if key in item and isinstance(item[key], dict):
                item[key].update(update[key])
            else:
                item[key] = update[key]
        return Response()

    async def driveDelete(self, request: Request):
        """Simulate Drive ``files.delete``. Unknown ids give 404."""
        item_id = request.match_info.get('id')
        await self._checkDriveHeaders(request)
        if item_id not in self.items:
            raise HTTPNotFound
        self.items.pop(item_id)
        return Response()

    async def driveQuery(self, request: Request):
        """Simulate Drive ``files.list`` for the query forms the add-on uses.

        Supported ``q`` values:
          * ``mimeType='<type>'`` selects items whose mimeType equals <type>;
          * ``'<id>' in parents`` selects children of folder <id> (404 if the
            folder is unknown, 403 if permission to it was revoked);
          * empty selects every item.
        Anything else raises HTTPBadRequest. Each match is trimmed to the
        requested ``fields``.
        """
        await self._checkDriveHeaders(request)
        q: str = request.query.get("q", "")
        wanted_fields = self.parseFields(request.query.get('fields', 'id'))

        if mimeTypeQueryPattern.match(q):
            mime = q[len("mimeType='"):-1]

            def selected(item):
                return item.get('mimeType', '') == mime
        elif parentsQueryPattern.match(q):
            folder = q[1:-len("' in parents")]
            if folder not in self.items:
                raise HTTPNotFound
            if folder in self.lostPermission:
                return Response(
                    status=403,
                    content_type="application/json",
                    text='{"error": {"errors": [{"reason": "forbidden"}]}}')

            def selected(item):
                return folder in item.get('parents', [])
        elif not q:
            def selected(item):
                return True
        else:
            raise HTTPBadRequest

        matches = [
            self.filter_fields(item, wanted_fields)
            for item in self.items.values() if selected(item)
        ]
        return json_response({'files': matches})

    async def driveCreate(self, request: Request):
        """Simulate Drive ``files.create`` (metadata only). Returns the new id."""
        await self._checkDriveHeaders(request)
        new_id = self.generateId(30)
        metadata = await request.json()
        created = self.formatItem(metadata, new_id)
        self.items[new_id] = created
        return json_response({'id': created['id']})

    async def driveStartUpload(self, request: Request):
        """Simulate starting a Drive resumable upload (``uploadType=resumable``).

        Validates the upload headers and parent folders, records the pending
        upload in ``self.upload_info``, and returns a ``Location`` header
        pointing at this server's progress URL for the new id.
        driveContinueUpload then consumes the chunks.

        Returns a 400 JSON ``storageQuotaExceeded`` error when
        ``simulate_out_of_drive_space`` is set, and a 403 when a parent's
        permission was revoked. Raises HTTPBadRequest / HTTPNotFound on
        invalid input.
        """
        if self.simulate_out_of_drive_space:
            return json_response(
                {"error": {
                    "errors": [{
                        "reason": "storageQuotaExceeded"
                    }]
                }},
                status=400)
        logging.getLogger().info("Drive start upload request")
        await self._checkDriveHeaders(request)
        if request.query.get('uploadType') != 'resumable':
            raise HTTPBadRequest()
        mimeType = request.headers.get('X-Upload-Content-Type', None)
        if mimeType is None:
            raise HTTPBadRequest()
        # -1 is a sentinel for "header missing".
        size = int(request.headers.get('X-Upload-Content-Length', -1))
        if size == -1:
            raise HTTPBadRequest()
        metadata = await request.json()
        id = self.generateId()

        # Validate parents
        if 'parents' in metadata:
            for parent in metadata['parents']:
                if parent not in self.items:
                    raise HTTPNotFound()
                if parent in self.lostPermission:
                    return Response(
                        status=403,
                        content_type="application/json",
                        text='{"error": {"errors": [{"reason": "forbidden"}]}}'
                    )
        # Only one upload can be in flight: this overwrites any previous one.
        self.upload_info['size'] = size
        self.upload_info['mime'] = mimeType
        self.upload_info['item'] = self.formatItem(metadata, id)
        self.upload_info['id'] = id
        self.upload_info['next_start'] = 0
        # NOTE(review): these run AFTER formatItem(metadata, id) above.
        # driveContinueUpload appends to upload_info['item']['bytes'], so
        # this only works if formatItem returns/aliases `metadata` or adds
        # 'bytes' itself. Confirm against formatItem's definition.
        metadata['bytes'] = bytearray()
        metadata['size'] = size
        resp = Response()
        resp.headers['Location'] = "http://localhost:" + \
            str(self._port) + "/upload/drive/v3/files/progress/" + id
        return resp

    async def driveContinueUpload(self, request: Request):
        """Simulate uploading one chunk of a Drive resumable upload.

        A ``Content-Range`` matching ``resumeBytesPattern`` is a status
        query. It is answered with 308 and a ``Range`` header for the bytes
        received so far. Otherwise the chunk is validated against the
        pending upload in ``self.upload_info`` and appended. The final chunk
        creates the item and returns its id. Intermediate chunks return 308.
        Invalid chunks raise HTTPBadRequest.
        """
        # Test hook: once current_chunk reaches waitOnChunk, signal the test
        # via _upload_chunk_trigger and block until it sets _upload_chunk_wait.
        # NOTE(review): current_chunk is not incremented on the waiting
        # branch, so every later call also takes that branch.
        if self.waitOnChunk > 0:
            if self.current_chunk == self.waitOnChunk:
                self._upload_chunk_trigger.set()
                await self._upload_chunk_wait.wait()
            else:
                self.current_chunk += 1
        id = request.match_info.get('id')
        if (self.getSetting('drive_upload_sleep') > 0):
            await self._time.sleepAsync(self.getSetting('drive_upload_sleep'))
        await self._checkDriveHeaders(request)
        if self.upload_info.get('id', "") != id:
            raise HTTPBadRequest()
        chunk_size = int(request.headers['Content-Length'])
        info = request.headers['Content-Range']
        if resumeBytesPattern.match(info):
            # Status query: report the received prefix (no Range if nothing yet).
            resp = Response(status=308)
            if self.upload_info['next_start'] != 0:
                resp.headers['Range'] = "bytes=0-{0}".format(
                    self.upload_info['next_start'] - 1)
            return resp
        if not bytesPattern.match(info):
            raise HTTPBadRequest()
        # Presumably bytesPattern guarantees three integers: start, end, total.
        numbers = intPattern.findall(info)
        start = int(numbers[0])
        end = int(numbers[1])
        total = int(numbers[2])
        if total != self.upload_info['size']:
            raise HTTPBadRequest()
        # Chunks must arrive contiguously.
        if start != self.upload_info['next_start']:
            raise HTTPBadRequest()
        # Every chunk except the last must be a multiple of 256 KiB.
        if not (end == total - 1 or chunk_size % (256 * 1024) == 0):
            raise HTTPBadRequest()
        if end > total - 1:
            raise HTTPBadRequest()

        # get the chunk
        received_bytes = await self.readAll(request)

        # See if we should fail the request. drive_upload_error_attempts
        # counts down successful chunks before the configured error fires.
        if self.getSetting("drive_upload_error") is not None:
            if self.getSetting("drive_upload_error_attempts") <= 0:
                raise HttpMultiException(self.getSetting("drive_upload_error"))
            else:
                self.update({
                    "drive_upload_error_attempts":
                    self.getSetting("drive_upload_error_attempts") - 1
                })

        # validate the chunk
        if len(received_bytes) != chunk_size:
            raise HTTPBadRequest()

        # Content-Range is inclusive, hence end - start + 1.
        if len(received_bytes) != end - start + 1:
            raise HTTPBadRequest()

        # Assumes upload_info['item'] carries a 'bytes' bytearray (see
        # driveStartUpload); confirm against formatItem.
        self.upload_info['item']['bytes'].extend(received_bytes)

        if len(self.upload_info['item']['bytes']) != end + 1:
            raise HTTPBadRequest()

        self.chunks.append(len(received_bytes))
        if end == total - 1:
            # upload is complete, so create the item
            self.items[self.upload_info['id']] = self.upload_info['item']
            return json_response({"id": self.upload_info['id']})
        else:
            # Return an incomplete response
            # For some reason, the tests like to stop right here
            resp = Response(status=308)
            self.upload_info['next_start'] = end + 1
            resp.headers['Range'] = "bytes=0-{0}".format(end)
            return resp

    # HASSIO METHODS BELOW
    async def _verifyHassioHeader(self, request) -> bool:
        if self.supervisor_sleep > 0:
            await asyncio.sleep(self.supervisor_sleep)
        if self.getSetting("hassio_error") is not None:
            raise HttpMultiException(self.getSetting("hassio_error"))
        self._verifyHeader(request, "Authorization",
                           "Bearer " + self.getSetting('ha_header'))

    def _verifyHaHeader(self, request) -> bool:
        if self._ha_error is not None:
            raise HttpMultiException(self._ha_error)
        self._verifyHeader(request, "Authorization",
                           "Bearer " + self.getSetting('ha_header'))

    def _verifyHeader(self, request, key: str, value: str) -> bool:
        if request.headers.get(key, None) != value:
            raise HTTPUnauthorized()

    def formatDataResponse(self, data: Any) -> Response:
        """Wrap *data* in the Supervisor success envelope.

        The envelope is ``{"result": "ok", "data": ...}``, returned as a JSON
        response object rather than a string.
        """
        return json_response({'result': 'ok', 'data': data})

    def checkForSupervisorError(self):
        """Return a bare error Response when ``supervisor_error`` is set, else None.

        A new Response object is built on every call.
        """
        if self.supervisor_error is None:
            return None
        return Response(status=self.supervisor_error)

    def formatErrorResponse(self, error: str) -> Response:
        """Wrap *error* in the Supervisor failure envelope ``{"result": error}``.

        The envelope is returned as a JSON response object, not a string.
        """
        return json_response({'result': error})

    async def hassioSnapshots(self, request: Request):
        """Simulate ``GET /snapshots``: list the info of every stored snapshot."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        listing = list(self.snapshots.values())
        return self.formatDataResponse({'snapshots': listing})

    async def hassioSupervisorInfo(self, request: Request):
        """Simulate ``GET /supervisor/info`` with the fixed add-on list.

        The supervisor-error check is evaluated once, and the add-on list is
        copied once. ``list(...)`` already returns a new list, so the former
        extra ``.copy()`` was redundant.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        return self.formatDataResponse({"addons": list(all_addons)})

    async def supervisorLogs(self, request: Request):
        """Simulate the Supervisor log endpoint with two fixed lines."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        log_text = "Supervisor Log line 1\nSupervisor Log Line 2"
        return Response(body=log_text)

    async def coreLogs(self, request: Request):
        """Simulate the Home Assistant core log endpoint with two fixed lines."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        log_text = "Core Log line 1\nCore Log Line 2"
        return Response(body=log_text)

    async def haInfo(self, request: Request):
        """Simulate the Supervisor's Home Assistant info endpoint.

        Each field is read from the matching setting. Note that ``port`` is
        backed by the ``ha_port`` setting, not ``port``.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        field_to_setting = (
            ("version", 'ha_version'),
            ("last_version", 'ha_last_version'),
            ("machine", 'machine'),
            ("ip_address", 'ip_address'),
            ("arch", 'arch'),
            ("image", 'image'),
            ("custom", 'custom'),
            ("boot", 'boot'),
            ("port", 'ha_port'),
            ("ssl", 'ssl'),
            ("watchdog", 'watchdog'),
            ("wait_boot", 'wait_boot'),
        )
        return self.formatDataResponse({
            field: self.getSetting(setting)
            for field, setting in field_to_setting
        })

    async def hassioNewFullSnapshot(self, request: Request):
        """Simulate ``POST /snapshots/new/full`` on the Supervisor.

        The optional JSON body may carry ``name`` and ``password``. While
        snapshots are blocked or one is in progress, the request is rejected
        with HTTPBadRequest, unless ``always_hard_lock`` is set; then it waits
        on the snapshot lock instead. A ``?seconds=`` query parameter (default
        ``snapshot_wait_time``) delays completion.

        Returns:
            A data response containing the new snapshot's ``slug``.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        if (self.block_snapshots or self.snapshot_in_progress
            ) and not self.getSetting('always_hard_lock'):
            raise HTTPBadRequest()
        input_json = {}
        try:
            input_json = await request.json()
        except Exception:
            # The body is optional: a missing/malformed body means "use the
            # defaults". Catching Exception (not a bare except) lets task
            # cancellation propagate.
            pass
        # Acquire OUTSIDE the try. If the wait is cancelled, the finally must
        # not run: asyncio.Lock.release() does not check ownership and would
        # unlock another task's critical section.
        await self._snapshot_lock.acquire()
        try:
            self.snapshot_in_progress = True
            await self._verifyHassioHeader(request)
            error = self.getSetting("hassio_snapshot_error")
            if error is not None:
                raise HttpMultiException(error)

            seconds = int(
                request.query.get('seconds',
                                  self.getSetting('snapshot_wait_time')))
            date = self._time.now()
            size = int(
                random.uniform(float(self.getSetting('snapshot_min_size')),
                               float(self.getSetting('snapshot_max_size'))))
            slug = self.generateId(8)
            name = input_json.get('name', "Default name")
            password = input_json.get('password', None)
            if seconds > 0:
                await asyncio.sleep(seconds)

            data = createSnapshotTar(slug, name, date, size, password=password)
            snapshot_info = parseSnapshotInfo(data)
            self.snapshots[slug] = snapshot_info
            self.snapshot_data[slug] = bytearray(data.getbuffer())
            return self.formatDataResponse({"slug": slug})
        finally:
            self.snapshot_in_progress = False
            self._snapshot_lock.release()

    async def hassioNewPartialSnapshot(self, request: Request):
        """Simulate ``POST /snapshots/new/partial`` on the Supervisor.

        The JSON body must contain ``name``, ``folders`` and ``addons``. It
        may also contain ``password``. Blocking, lock and ``?seconds=``
        semantics match hassioNewFullSnapshot.

        Returns:
            A data response containing the new snapshot's ``slug``.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        if (self.block_snapshots or self.snapshot_in_progress
            ) and not self.getSetting('always_hard_lock'):
            raise HTTPBadRequest()
        input_json = await request.json()
        # Acquire OUTSIDE the try. A cancelled acquire() must not reach the
        # finally, because asyncio.Lock.release() would release another
        # task's lock.
        await self._snapshot_lock.acquire()
        try:
            self.snapshot_in_progress = True
            await self._verifyHassioHeader(request)
            seconds = int(
                request.query.get('seconds',
                                  self.getSetting('snapshot_wait_time')))
            date = self._time.now()
            size = int(
                random.uniform(float(self.getSetting('snapshot_min_size')),
                               float(self.getSetting('snapshot_max_size'))))
            slug = self.generateId(8)
            name = input_json['name']
            password = input_json.get('password', None)
            if seconds > 0:
                await asyncio.sleep(seconds)

            data = createSnapshotTar(slug,
                                     name,
                                     date,
                                     size,
                                     included_folders=input_json['folders'],
                                     included_addons=input_json['addons'],
                                     password=password)
            snapshot_info = parseSnapshotInfo(data)
            self.snapshots[slug] = snapshot_info
            self.snapshot_data[slug] = bytearray(data.getbuffer())
            return self.formatDataResponse({"slug": slug})
        finally:
            self.snapshot_in_progress = False
            self._snapshot_lock.release()

    async def uploadNewSnapshot(self, request: Request):
        """Simulate uploading a snapshot tar to the Supervisor.

        The body is parsed as a snapshot archive and stored under its slug.
        Any parsing failure is printed and reported as "Bad snapshot".
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        try:
            raw = await self.readAll(request)
            info = parseSnapshotInfo(BytesIO(raw))
            slug = info['slug']
            self.snapshots[slug] = info
            self.snapshot_data[slug] = raw
            return self.formatDataResponse({"slug": slug})
        except Exception as e:
            print(str(e))
            return self.formatErrorResponse("Bad snapshot")

    async def hassioDelete(self, request: Request):
        """Simulate deleting a snapshot by slug. Unknown slugs give 404."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        slug = request.match_info.get('slug')
        await self._verifyHassioHeader(request)
        if slug not in self.snapshots:
            raise HTTPNotFound()
        self.snapshots.pop(slug)
        self.snapshot_data.pop(slug)
        return self.formatDataResponse("deleted")

    async def hassioSnapshotInfo(self, request: Request):
        """Simulate fetching one snapshot's info by slug. Unknown slugs give 404."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        slug = request.match_info.get('slug')
        await self._verifyHassioHeader(request)
        info = self.snapshots.get(slug) if slug in self.snapshots else None
        if slug not in self.snapshots:
            raise HTTPNotFound()
        return self.formatDataResponse(info)

    async def hassioSnapshotDownload(self, request: Request):
        """Simulate downloading a snapshot's raw bytes (range requests allowed)."""
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        slug = request.match_info.get('slug')
        await self._verifyHassioHeader(request)
        if slug not in self.snapshot_data:
            raise HTTPNotFound()
        payload = self.snapshot_data[slug]
        return self.serve_bytes(request, payload)

    async def hassioSelfInfo(self, request: Request):
        """Simulate the Supervisor's "this add-on" info endpoint.

        The response carries the web UI and ingress URLs, the add-on slug,
        and the current options.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        info = {}
        info["webui"] = self.getSetting('web_ui')
        info['ingress_url'] = self.getSetting('ingress_url')
        info["slug"] = self.getSetting('addon_slug')
        info["options"] = self._options
        return self.formatDataResponse(info)

    async def hassioInfo(self, request: Request):
        """Simulate the Supervisor's host/system info endpoint.

        Every field has the same name as the setting it is read from.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        names = ("supervisor", "homeassistant", "hassos", "hostname",
                 "machine", "arch", "supported_arch", "channel")
        return self.formatDataResponse(
            {name: self.getSetting(name) for name in names})

    async def hassioAuthenticate(self, request: Request):
        """Simulate Supervisor user authentication against the fixed credentials.

        Raises HTTPBadRequest if the username or password does not match.
        """
        failure = self.checkForSupervisorError()
        if failure is not None:
            return failure
        await self._verifyHassioHeader(request)
        credentials = await request.json()
        matches = (credentials.get("username") == self._username
                   and credentials.get("password") == self._password)
        if not matches:
            raise HTTPBadRequest()
        return self.formatDataResponse({})

    async def haStateUpdate(self, request: Request):
        """Record an entity state posted to the simulated Home Assistant API.

        The entity id comes from the URL. The body's ``state`` and
        ``attributes`` are stored in ``self._entities`` and
        ``self._attributes`` respectively. A missing field raises ``KeyError``
        after any earlier assignment has already happened.
        """
        entity_id = request.match_info.get('entity')
        self._verifyHaHeader(request)
        payload = await request.json()
        self._entities[entity_id] = payload['state']
        self._attributes[entity_id] = payload['attributes']
        return Response()

    async def haEventUpdate(self, request: Request):
        """Record an event fired at the simulated Home Assistant API.

        Appends ``(event name from the URL, decoded JSON body)`` to
        ``self._events``.
        """
        event_name = request.match_info.get('name')
        self._verifyHaHeader(request)
        event_body = await request.json()
        self._events.append((event_name, event_body))
        return Response()

    async def createNotification(self, request: Request):
        """Simulate ``persistent_notification/create``.

        Prints the request body and keeps a shallow copy of it in
        ``self._notification``.
        """
        self._verifyHaHeader(request)
        body = await request.json()
        print("Created notification with: {}".format(body))
        self._notification = body.copy()
        return Response()

    async def dismissNotification(self, request: Request):
        """Simulate ``persistent_notification/dismiss``.

        Prints the request body and forgets the stored notification.
        """
        self._verifyHaHeader(request)
        body = await request.json()
        print("Dismissed notification with: {}".format(body))
        self._notification = None
        return Response()

    async def hassioUpdateOptions(self, request: Request):
        """Simulate supervisor ``/addons/self/options``.

        If the supervisor-error check yields a response, it is returned as-is.
        Otherwise the hassio header is verified and a shallow copy of the
        body's ``options`` object replaces ``self._options``.
        """
        # Evaluate the error check once so the guard and the returned
        # response come from the same call.
        supervisor_error = self.checkForSupervisorError()
        if supervisor_error is not None:
            return supervisor_error
        await self._verifyHassioHeader(request)
        self._options = (await request.json())['options'].copy()
        return self.formatDataResponse({})

    async def slugRedirect(self, request: Request):
        """Redirect the ingress self-slug URL to the local ingress port (303)."""
        ingress_port = str(self.config.get(Setting.INGRESS_PORT))
        raise HTTPSeeOther("https://localhost:" + ingress_port)

    @middleware
    async def error_middleware(self, request: Request, handler):
        """Log each request URL, apply configured failures, and map errors.

        For every entry of ``self.match_errors`` whose ``url`` regex matches
        the request URL, the first ``attempts`` matching requests are let
        through (the counter is decremented). After that, matching requests
        get a bare response with the entry's ``status``.

        If the real handler raises, the request is read via ``readAll``
        (presumably to drain the body). Then:

        * ``HttpMultiException`` becomes a bare response with its ``status_code``.
        * ``HTTPException`` instances propagate unchanged.
        * Anything else is logged and returned as a 502 JSON response.
        """
        url = str(request.url)
        self.urls.append(url)
        for rule in self.match_errors:
            if not re.match(rule['url'], url):
                continue
            if rule['attempts'] > 0:
                rule['attempts'] -= 1
                continue
            await self.readAll(request)
            return Response(status=rule['status'])
        try:
            return await handler(request)
        except Exception as ex:
            await self.readAll(request)
            if isinstance(ex, HttpMultiException):
                return Response(status=ex.status_code)
            if isinstance(ex, HTTPException):
                raise
            logger.printException(ex)
            return json_response(str(ex), status=502)

    def createApp(self):
        """Build the aiohttp ``Application`` for the simulated services.

        Installs ``error_middleware`` and the table from ``routes()``. The
        auth server then gets to extend the same app via ``buildApp``.
        """
        application = Application(middlewares=[self.error_middleware])
        application.add_routes(self.routes())
        self._authserver.buildApp(application)
        return application

    async def start(self, port):
        """Serve ``createApp()`` on all interfaces (``0.0.0.0``) at *port*.

        The runner is kept on ``self.runner`` so ``stop()`` can tear it down.
        """
        runner = aiohttp.web.AppRunner(self.createApp())
        self.runner = runner
        await runner.setup()
        tcp_site = aiohttp.web.TCPSite(runner, "0.0.0.0", port=port)
        await tcp_site.start()

    async def stop(self):
        """Shut down, then clean up, the runner created by ``start()``."""
        runner = self.runner
        await runner.shutdown()
        await runner.cleanup()

    def toggleBlockSnapshot(self, request: Request):
        """Debug endpoint: flip ``snapshot_in_progress`` and report the new state.

        NOTE(review): this is a plain (non-async) function registered as an
        aiohttp handler. aiohttp deprecates non-coroutine handlers, so
        consider making it ``async`` if nothing calls it directly.
        """
        now_blocking = not self.snapshot_in_progress
        self.snapshot_in_progress = now_blocking
        return Response(text="Blocking" if now_blocking else "Not Blocking")

    def routes(self):
        """Return the aiohttp route table for every simulated endpoint.

        The table covers:

        * supervisor-style endpoints (add-on info/options, auth, snapshots, logs);
        * Home Assistant state, event and notification calls;
        * Drive v3 file and upload paths;
        * OAuth token/authorize paths;
        * a few control hooks (``/updatesettings``, ``/readfile``,
          ``/uploadfile``, ``/doareset``, ``/debug/toggleblock``), presumably
          used by tests to drive the simulation.
        """
        return [
            # Supervisor add-on options and Home Assistant service/API calls.
            post('/addons/self/options', self.hassioUpdateOptions),
            post("/homeassistant/api/services/persistent_notification/dismiss",
                 self.dismissNotification),
            post("/homeassistant/api/services/persistent_notification/create",
                 self.createNotification),
            post("/homeassistant/api/events/{name}", self.haEventUpdate),
            post("/homeassistant/api/states/{entity}", self.haStateUpdate),
            # Supervisor auth/info; /auth is accepted with both POST and GET.
            post('/auth', self.hassioAuthenticate),
            get('/auth', self.hassioAuthenticate),
            get('/info', self.hassioInfo),
            get('/addons/self/info', self.hassioSelfInfo),
            # Snapshot management.
            get('/snapshots/{slug}/download', self.hassioSnapshotDownload),
            get('/snapshots/{slug}/info', self.hassioSnapshotInfo),
            post('/snapshots/{slug}/remove', self.hassioDelete),
            post('/snapshots/new/upload', self.uploadNewSnapshot),
            get('/snapshots/new/upload', self.uploadNewSnapshot),
            get('/debug/toggleblock', self.toggleBlockSnapshot),
            post('/snapshots/new/partial', self.hassioNewPartialSnapshot),
            post('/snapshots/new/full', self.hassioNewFullSnapshot),
            get('/snapshots/new/full', self.hassioNewFullSnapshot),
            get('/homeassistant/info', self.haInfo),
            get('/supervisor/info', self.hassioSupervisorInfo),
            get('/supervisor/logs', self.supervisorLogs),
            get('/core/logs', self.coreLogs),
            get('/snapshots', self.hassioSnapshots),
            # Drive v3 upload and file endpoints.
            put('/upload/drive/v3/files/progress/{id}',
                self.driveContinueUpload),
            post('/upload/drive/v3/files/', self.driveStartUpload),
            post('/drive/v3/files/', self.driveCreate),
            get('/drive/v3/files/', self.driveQuery),
            delete('/drive/v3/files/{id}/', self.driveDelete),
            patch('/drive/v3/files/{id}/', self.driveUpdate),
            get('/drive/v3/files/{id}/', self.driveGetItem),
            # Simulation control hooks.
            post('/updatesettings', self.updateSettings),
            get('/readfile', self.readFile),
            post('/uploadfile', self.uploadfile),
            post('/doareset', self.reset),
            # OAuth token exchange / authorization.
            post('/oauth2/v4/token', self.driveRefreshToken),
            get('/o/oauth2/v2/auth', self.driveAuthorize),
            post('/token', self.driveToken),
            get('/hassio/ingress/self_slug', self.slugRedirect)
        ]

    def generateId(self, length: int = 30) -> str:
        """Return a new deterministic, unique id of *length* characters.

        The id is the incremented ``self.id_counter`` followed by the digits
        ``0123456789`` repeated as needed to pad it to *length* characters.
        If the counter alone is at least *length* digits long, it is returned
        without padding.
        """
        self.id_counter += 1
        prefix = str(self.id_counter)
        # Pad with single digits so each padding step adds exactly one
        # character. Appending str(i) directly adds two characters per step
        # once i reaches 10, overshooting the requested length.
        padding = ''.join(str(i % 10) for i in range(length - len(prefix)))
        return prefix + padding

    def timeToRfc3339String(self, time) -> Any:
        """Format *time* as ``YYYY-MM-DDTHH:MM:SSZ`` (no fractional seconds).

        The trailing ``Z`` is literal and no timezone conversion is done.
        NOTE(review): callers presumably pass UTC times; confirm.
        """
        rfc3339_pattern = "%Y-%m-%dT%H:%M:%SZ"
        return time.strftime(rfc3339_pattern)

    def formatItem(self, base, id):
        """Fill in Drive-style file metadata on *base* in place and return it.

        Sets full ``capabilities``, ``trashed=False``, the given ``id`` and a
        ``modifiedTime`` stamped from ``self._time.now()``.
        """
        base.update({
            'capabilities': {
                'canAddChildren': True,
                'canListChildren': True,
                'canDeleteChildren': True,
            },
            'trashed': False,
            'id': id,
        })
        # The timestamp is taken after the other fields are written, as before.
        base['modifiedTime'] = self.timeToRfc3339String(self._time.now())
        return base

    def parseFields(self, source: str):
        """Split a Drive-style ``fields`` selector into plain field names.

        ``"files(id,name),nextPageToken"`` becomes
        ``["id", "name", "nextPageToken"]``. A leading ``files(`` and a
        trailing ``)`` are stripped from each comma-separated entry,
        independently. Whitespace around entries is ignored.
        """
        fields = []
        for field in source.split(","):
            field = field.strip()
            # A single-field selector such as "files(id)" carries both the
            # wrapper and the closing parenthesis, so strip each on its own
            # rather than only whichever check matched first.
            if field.startswith("files("):
                field = field[len("files("):]
            if field.endswith(")"):
                field = field[:-1]
            fields.append(field)
        return fields

    def filter_fields(self, item: Dict[str, Any], fields) -> Dict[str, Any]:
        """Return a new dict with only the entries of *item* named in *fields*.

        Keys follow the order of *fields*; names absent from *item* are skipped.
        """
        return {name: item[name] for name in fields if name in item}
예제 #41
0
class SessionBase(asyncio.Protocol):
    """Base class of networking sessions.

    There is no client / server distinction other than who initiated
    the connection.

    To initiate a connection to a remote server pass host, port and
    proxy to the constructor, and then call create_connection().  Each
    successful call should have a corresponding call to close().

    Alternatively if used in a with statement, the connection is made
    on entry to the block, and closed on exit from the block.

    NOTE(review): this base constructor only accepts ``framer`` and
    ``loop``. ``create_connection()``, the host/port/proxy arguments and the
    context-manager support are not defined here; presumably a subclass or
    connector supplies them. Confirm.
    """

    # Once this many errors are counted by _bump_errors(), the transport is closed.
    max_errors = 10

    def __init__(self, *, framer=None, loop=None):
        # framer: message framer; falls back to self.default_framer().
        # loop: event loop; falls back to asyncio.get_event_loop().
        self.framer = framer or self.default_framer()
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.transport = None
        # Set when a connection is made
        self._address = None
        self._proxy_address = None
        # For logger.debug messages
        self.verbosity = 0
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        # Task created in connection_made() to run _receive_messages().
        self._pm_task = None
        self._task_group = TaskGroup(self.loop)
        # Force-close a connection if a send doesn't succeed in this time
        self.max_send_delay = 60
        # Statistics.  The RPC object also keeps its own statistics.
        self.start_time = time.perf_counter()
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = self.start_time
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = self.start_time
        self.last_packet_received = self.start_time

    async def _limited_wait(self, secs):
        """Wait up to *secs* seconds for the send buffer to drain.

        On timeout the connection is aborted and ``asyncio.TimeoutError``
        is raised.
        """
        try:
            await asyncio.wait_for(self._can_send.wait(), secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.TimeoutError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        """Frame *message* and write it to the transport, updating send stats.

        Waits (bounded by ``max_send_delay``) while writing is paused. If
        the connection is closing or gone, the message is silently dropped.
        """
        if not self._can_send.is_set():
            await self._limited_wait(self.max_send_delay)
        if not self.is_closing():
            framed_message = self.framer.frame(message)
            self.send_size += len(framed_message)
            self.send_count += 1
            self.last_send = time.perf_counter()
            if self.verbosity >= 4:
                self.logger.debug(f'Sending framed message {framed_message}')
            self.transport.write(framed_message)

    def _bump_errors(self):
        """Count one error; close the transport once ``max_errors`` is reached."""
        self.errors += 1
        if self.errors >= self.max_errors:
            # Don't await self.close() because that is self-cancelling
            self._close()

    def _close(self):
        """Close the transport, if any, without waiting."""
        if self.transport:
            self.transport.close()

    # asyncio framework
    def data_received(self, framed_message):
        """Called by asyncio when a message comes in."""
        # Only byte count and timestamp are updated here; recv_count is not.
        # NOTE(review): presumably incremented by a subclass's message loop.
        self.last_packet_received = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
        self.framer.received_bytes(framed_message)

    def pause_writing(self):
        """Transport calls when the send buffer is full."""
        # Block senders and stop reading too, applying backpressure to the peer.
        if not self.is_closing():
            self._can_send.clear()
            self.transport.pause_reading()

    def resume_writing(self):
        """Transport calls when the send buffer has room."""
        if not self._can_send.is_set():
            self._can_send.set()
            self.transport.resume_reading()

    def connection_made(self, transport):
        """Called by asyncio when a connection is established.

        Derived classes overriding this method must call this first."""
        self.transport = transport
        # This would throw if called on a closed SSL transport.  Fixed
        # in asyncio in Python 3.6.1 and 3.5.4
        peer_address = transport.get_extra_info('peername')
        # If the Socks proxy was used then _address is already set to
        # the remote address
        if self._address:
            self._proxy_address = peer_address
        else:
            self._address = peer_address
        # _receive_messages() is not defined in this class; presumably
        # provided by a subclass.
        self._pm_task = self.loop.create_task(self._receive_messages())

    def connection_lost(self, exc):
        """Called by asyncio when the connection closes.

        Tear down things done in connection_made."""
        self._address = None
        self.transport = None
        self._task_group.cancel()
        if self._pm_task:
            self._pm_task.cancel()
        # Release waiting tasks
        self._can_send.set()

    # External API
    def default_framer(self):
        """Return a default framer."""
        raise NotImplementedError

    def peer_address(self):
        """Returns the peer's address (Python networking address), or None if
        no connection or an error.

        This is the result of socket.getpeername() when the connection
        was made.
        """
        return self._address

    def peer_address_str(self):
        """Returns the peer's IP address and port as a human-readable
        string."""
        if not self._address:
            return 'unknown'
        ip_addr_str, port = self._address[:2]
        # A colon means an IPv6 literal, which is bracketed before the port.
        if ':' in ip_addr_str:
            return f'[{ip_addr_str}]:{port}'
        else:
            return f'{ip_addr_str}:{port}'

    def is_closing(self):
        """Return True if the connection is closing."""
        return not self.transport or self.transport.is_closing()

    def abort(self):
        """Forcefully close the connection."""
        if self.transport:
            self.transport.abort()

    # TODO: replace with synchronous_close
    async def close(self, *, force_after=30):
        """Close the connection and return when closed."""
        # Close the transport, then give the message task up to force_after
        # seconds to finish. After that, abort and await the task;
        # cancellation of that task is expected and suppressed.
        self._close()
        if self._pm_task:
            with suppress(CancelledError):
                await asyncio.wait([self._pm_task], timeout=force_after)
                self.abort()
                await self._pm_task

    def synchronous_close(self):
        """Close the transport and cancel the message task without waiting."""
        self._close()
        if self._pm_task and not self._pm_task.done():
            self._pm_task.cancel()
예제 #42
0
class Controller(AbstractController):
    """
    Controller for one camera device.

    Owns a ``CamView`` and a ``CamModel`` that runs in its own ``Process``.
    Commands go to the model over ``_model_msg_pipe``. Replies are polled in
    an executor thread and dispatched on the event loop via
    ``self._switcher``. Frames arrive on ``_model_image_pipe`` and are
    forwarded to the view.
    """

    def __init__(self,
                 cam_index: int = 0,
                 lang: LangEnum = LangEnum.ENG,
                 log_handlers: [StreamHandler] = None):
        """
        Build the view, spawn the model process and start reading its pipe.
        Must be called while an event loop is running (uses get_running_loop()).
        :param cam_index: Index of the camera to open.
        :param lang: Initial language for the view and model.
        :param log_handlers: Optional extra handlers added to this logger.
        """
        self._logger = getLogger(__name__)
        if log_handlers:
            for h in log_handlers:
                self._logger.addHandler(h)
        self._logger.debug("Initializing")
        self.cam_index = cam_index
        cam_name = "CAM_" + str(self.cam_index)
        view = CamView(cam_name, log_handlers)
        super().__init__(view)
        self.view.show_initialization()
        self.view.set_config_active(False)
        # TODO: Get logging in here. See https://docs.python.org/3/howto/logging-cookbook.html find multiprocessing.
        self._model_msg_pipe, msg_pipe = Pipe()  # For messages/commands.
        self._model_image_pipe, img_pipe = Pipe(False)  # For images.
        self._model = Process(target=CamModel,
                              args=(msg_pipe, img_pipe, self.cam_index))
        # Model message type -> handler run on the event loop.
        self._switcher = {
            defs.ModelEnum.FAILURE: self.err_cleanup,
            defs.ModelEnum.CUR_FPS: self._update_view_fps,
            defs.ModelEnum.CUR_RES: self._update_view_resolution,
            defs.ModelEnum.CLEANUP: self._set_model_cleaned,
            defs.ModelEnum.STOP: self._set_saved,
            defs.ModelEnum.STAT_UPD: self._show_init_progress,
            defs.ModelEnum.START: self._finalize
        }
        # Thread-side stop flag for the pipe-polling loops.
        self._stop = TEvent()
        self._loop = get_running_loop()
        self._model_cleaned = Event()
        self._ended = Event()
        # Two workers: one for _handle_pipe, one for _update_feed.
        self._executor = ThreadPoolExecutor(2)
        self._loop.run_in_executor(self._executor, self._handle_pipe)
        self._update_feed_flag = TEvent()
        self._update_feed_flag.set()
        self._handle_pipe_flag = TEvent()
        self._handle_pipe_flag.set()
        self._model.start()
        self.send_msg_to_model((defs.ModelEnum.INITIALIZE, None))
        self.set_lang(lang)
        self._res_list = list()
        self._setup_handlers()
        self._exp_info_for_later = [str(), str(),
                                    0]  # path, cond_name, block_num
        self._create_exp_later_task = None
        self._start_exp_later_task = None
        self._cleaning = False
        self._initialized = Event()
        self._running = False
        self._logger.debug("Initialized")

    def set_lang(self, lang: LangEnum) -> None:
        """
        Set the language for this device.
        :param lang: The language to use.
        :return None:
        """
        self._logger.debug("running")
        self.view.language = lang
        self.send_msg_to_model((defs.ModelEnum.LANGUAGE, lang))
        self._logger.debug("done")

    async def cleanup(self, discard: bool = False) -> None:
        """
        Cleanup this object and prep for app closure.
        Asks the model to clean up, waits for its confirmation, stops the
        polling threads and joins the model process.
        :param discard: Quit without saving.
        :return None:
        """
        self._logger.debug("running")
        self._cleaning = True
        self.send_msg_to_model((defs.ModelEnum.CLEANUP, discard))
        await self._model_cleaned.wait()
        self._stop.set()
        if self._model.is_alive():
            # NOTE(review): Process.join() blocks the event loop until the
            # model exits; consider awaiting it via run_in_executor.
            self._model.join()
        self._ended.set()
        self.view.save_window_state()
        self._cleaning = False
        self._logger.debug("done")

    def await_saved(self) -> futures:
        """
        Signal main app that this device data has been saved.
        NOTE(review): self.saved is not created in this class; presumably
        provided by AbstractController.
        :return futures: Event to signal saving done.
        """
        return await_event(self.saved)

    def create_exp(self, path: str, cond_name: str) -> None:
        """
        Handle experiment created for this device.
        If the camera is not initialized yet, the request is deferred until it is.
        :param path: The path to use to save data.
        :param cond_name: The optional condition name for this experiment.
        :return None:
        """
        self._logger.debug("running")
        self._exp_info_for_later[0] = path
        self._exp_info_for_later[1] = cond_name
        if self._initialized.is_set():
            self.send_msg_to_model((defs.ModelEnum.COND_NAME, cond_name))
            self.view.set_config_active(False)
            self.send_msg_to_model((defs.ModelEnum.START, path))
            self._running = True
            self.saved.clear()
        else:
            if self._create_exp_later_task is not None:
                self._create_exp_later_task.cancel()
            self._create_exp_later_task = create_task(self._create_exp_later())
        self._logger.debug("done")

    async def _create_exp_later(self) -> None:
        """
        Wait until camera is initialized, then re-run create_exp with the
        stored path and condition name.
        :return None:
        """
        await self._initialized.wait()
        self.create_exp(self._exp_info_for_later[0],
                        self._exp_info_for_later[1])

    def end_exp(self) -> None:
        """
        Handle experiment ended for this device.
        :return None:
        """
        self._logger.debug("running")
        if self._running:
            self.send_msg_to_model((defs.ModelEnum.STOP, None))
            self.view.set_config_active(True)
            self.send_msg_to_model((defs.ModelEnum.BLOCK_NUM, 0))
            self._running = False
            if self._create_exp_later_task is not None:
                self._create_exp_later_task.cancel()
                self._create_exp_later_task = None
        self._logger.debug("done")

    def start_exp(self, block_num: int, cond_name: str) -> None:
        """
        Handle start exp signal for this camera.
        If the camera is not initialized yet, the request is deferred until it is.
        :param block_num: The current block number.
        :param cond_name: The name for this part of the experiment.
        :return None:
        """
        self._logger.debug("running")
        self._exp_info_for_later[1] = cond_name
        self._exp_info_for_later[2] = block_num
        if self._initialized.is_set():
            self.send_msg_to_model((defs.ModelEnum.BLOCK_NUM, block_num))
            self.send_msg_to_model((defs.ModelEnum.COND_NAME, cond_name))
            self.send_msg_to_model((defs.ModelEnum.EXP_STATUS, True))
        else:
            if self._start_exp_later_task is not None:
                self._start_exp_later_task.cancel()
            self._start_exp_later_task = create_task(self._start_exp_later())
        self._logger.debug("done")

    async def _start_exp_later(self) -> None:
        """
        Wait until camera is initialized, then re-run start_exp with the
        stored block number and condition name.
        :return None:
        """
        await self._initialized.wait()
        self.start_exp(self._exp_info_for_later[2],
                       self._exp_info_for_later[1])

    def stop_exp(self) -> None:
        """
        Alert this camera when experiment stops.
        :return None:
        """
        self._logger.debug("running")
        self.send_msg_to_model((defs.ModelEnum.EXP_STATUS, False))
        # NOTE(review): this cancels the deferred create_exp task, not the
        # deferred start_exp task; confirm which one is intended.
        if self._create_exp_later_task is not None:
            self._create_exp_later_task.cancel()
            self._create_exp_later_task = None
        self._logger.debug("done")

    def update_keyflag(self, flag: str) -> None:
        """
        Handle keyflag changes for this camera.
        :param flag: The new flag.
        :return None:
        """
        self._logger.debug("running")
        self.send_msg_to_model((defs.ModelEnum.KEYFLAG, flag))
        self._logger.debug("done")

    def update_cond_name(self, name: str) -> None:
        """
        Update condition name for this device.
        :param name: The new condition name.
        :return None:
        """
        self._logger.debug("running")
        self.send_msg_to_model((defs.ModelEnum.COND_NAME, name))
        self._logger.debug("done")

    def update_resolution(self) -> None:
        """
        Get resolution selection from View and pass to model.
        :return None:
        """
        self._logger.debug("running")
        cur_res = self.view.resolution
        for res in self._res_list:
            if res[0] == cur_res:
                self.send_msg_to_model((defs.ModelEnum.SET_RES, res[1]))
                break
        self._logger.debug("done")

    def update_fps(self) -> None:
        """
        Get fps selection from View and pass to model.
        :return None:
        """
        self._logger.debug("running")
        new_fps = float(self.view.fps)
        self.send_msg_to_model((defs.ModelEnum.SET_FPS, new_fps))
        self._logger.debug("done")

    def update_show_feed(self) -> None:
        """
        Get show feed bool from view and pass to model.
        The feed is shown only when both "use feed" and "use cam" are enabled.
        :return None:
        """
        self._logger.debug("running")
        if not self.view.use_feed or not self.view.use_cam:
            self._update_feed_flag.clear()
            self.view.update_image(msg="No Feed")
        else:
            self._update_feed_flag.set()
        self.send_msg_to_model((defs.ModelEnum.SET_USE_FEED, self.view.use_feed
                                and self.view.use_cam))
        self._logger.debug("done")

    def update_use_cam(self) -> None:
        """
        Set this camera active or inactive.
        :return None:
        """
        self._logger.debug("running")
        self.send_msg_to_model((defs.ModelEnum.SET_USE_CAM, self.view.use_cam))
        self.update_show_feed()
        self._logger.debug("done")

    def update_use_overlay(self) -> None:
        """
        Toggle whether overlay is being used on this camera.
        :return None:
        """
        self.send_msg_to_model((defs.ModelEnum.OVERLAY, self.view.use_overlay))

    def get_index(self) -> int:
        """
        Get this camera index.
        :return int: The camera index.
        """
        return self.cam_index

    def await_ended(self) -> futures:
        """
        Return an awaitable for the "ended" event, which cleanup() sets.
        :return futures: Completes once cleanup has finished.
        """
        return await_event(self._ended)

    def err_cleanup(self) -> None:
        """
        Handle cleanup when camera fails.
        Starts a discarding cleanup unless one is already in progress.
        :return None:
        """
        self._logger.debug("running")
        self._logger.warning("Camera error occurred.")
        if not self._cleaning:
            create_task(self.cleanup(True))
        self._logger.debug("done")

    def _set_saved(self) -> None:
        """
        Set saved signal.
        :return None:
        """
        self._loop.call_soon_threadsafe(self.saved.set)

    def _handle_pipe(self) -> None:
        """
        Poll the model message pipe and dispatch each message on the event loop.
        Runs in an executor thread until self._stop is set or the pipe breaks.
        :return None:
        """
        self._logger.debug("running")
        try:
            while not self._stop.is_set():
                # Drain everything available before sleeping; handling only
                # one message per second made bursts (e.g. init progress
                # updates) back up.
                while self._model_msg_pipe.poll():
                    msg = self._model_msg_pipe.recv()
                    if msg[0] in self._switcher.keys():
                        # Handlers touch the view, so run them on the loop
                        # thread rather than this executor thread.
                        if msg[1] is not None:
                            self._loop.call_soon_threadsafe(
                                self._switcher[msg[0]], msg[1])
                        else:
                            self._loop.call_soon_threadsafe(
                                self._switcher[msg[0]])
                sleep(1)
        except OSError:
            # Includes BrokenPipeError: the model side is gone, stop polling.
            self._logger.debug("model message pipe closed")

    def _update_feed(self) -> None:
        """
        Update view with latest image from camera.
        Runs in an executor thread until self._stop is set or the pipe breaks.
        :return None:
        """
        self._logger.debug("running")
        try:
            while not self._stop.is_set():
                next_image = None
                # Drain the pipe and keep only the newest frame.
                while self._model_image_pipe.poll():
                    next_image = self._model_image_pipe.recv()
                if next_image is not None and self._update_feed_flag.is_set():
                    converted_image = self.convert_image_to_qt_format(
                        next_image)
                    self._loop.call_soon_threadsafe(self.view.update_image,
                                                    converted_image)
                sleep(.008)
        except OSError:
            # Includes BrokenPipeError: the model side is gone, stop polling.
            self._logger.debug("model image pipe closed")

    def _show_init_progress(self, progress: int) -> None:
        """
        Update user on camera initialization progress.
        :param progress: The latest progress update.
        :return None:
        """
        self.view.update_init_bar(progress)

    def _finalize(self, init_results: list) -> None:
        """
        Tell model to start. Tell view to show images.
        :param init_results: [max fps (int), list of supported resolutions].
        :return None:
        """
        self._logger.debug("running")
        self._loop.run_in_executor(self._executor, self._update_feed)
        self.send_msg_to_model((defs.ModelEnum.SET_USE_CAM, True))
        self.send_msg_to_model((defs.ModelEnum.SET_USE_FEED, True))
        fps = init_results[0]
        fps_list = [str(x) for x in range(1, fps + 1)]
        res_list = init_results[1]
        # Pair each resolution's display label with its raw value.
        self._res_list = [((str(x[0]) + ", " + str(x[1])), x)
                          for x in res_list]
        self.view.resolution_list = [x[0] for x in self._res_list]
        self.view.fps_list = fps_list
        self.send_msg_to_model((defs.ModelEnum.SET_FPS, fps))
        self.send_msg_to_model((defs.ModelEnum.GET_FPS, None))
        self.send_msg_to_model((defs.ModelEnum.GET_RES, None))
        self.view.set_config_active(True)
        self.view.show_images()
        self._initialized.set()
        self._logger.debug("done")

    def _update_view_fps(self, new_fps: int) -> None:
        """
        Update view object fps display with new value.
        :param new_fps: The new value.
        :return None:
        """
        self._logger.debug("running")
        new_fps = str(new_fps)
        self.view.fps = new_fps
        self._logger.debug("done")

    def _update_view_resolution(self, new_resolution: tuple) -> None:
        """
        Update view object resolution display with new value.
        :param new_resolution: The new value.
        :return None:
        """
        self._logger.debug("running")
        for res in self._res_list:
            if res[1] == new_resolution:
                self.view.resolution = res[0]
        self._logger.debug("done")

    def _setup_handlers(self) -> None:
        """
        Connect handlers to view object.
        :return None:
        """
        self._logger.debug("running")
        self.view.set_fps_selector_handler(self.update_fps)
        self.view.set_resolution_selector_handler(self.update_resolution)
        self.view.set_show_feed_button_handler(self.update_show_feed)
        self.view.set_use_cam_button_handler(self.update_use_cam)
        self.view.set_use_overlay_button_handler(self.update_use_overlay)
        self._logger.debug("done")

    def _set_model_cleaned(self) -> None:
        """
        Set flag that model is done with cleanup.
        :return None:
        """
        self._logger.debug("running")
        self._model_cleaned.set()
        self._logger.debug("done")

    def send_msg_to_model(self, msg) -> None:
        """
        Best-effort wrapper for pipe.send().
        :param msg: (ModelEnum, payload) tuple to send to the model.
        :return None:
        """
        try:
            self._model_msg_pipe.send(msg)
        except OSError:
            # Includes BrokenPipeError: once the model process has exited,
            # further commands are deliberately dropped.
            pass

    @staticmethod
    def convert_image_to_qt_format(image: ndarray) -> QPixmap:
        """
        Convert image to suitable format for display in Qt.
        :param image: The BGR image to convert.
        :return QPixmap: The converted image.
        """
        rgb_image = cvtColor(image, COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        res = QImage(rgb_image.data, w, h, ch * w, QImage.Format_RGB888)
        ret = QPixmap.fromImage(res)
        return ret
예제 #43
0
파일: actor.py 프로젝트: iakinsey/illume
class Actor(object):

    """
    Main actor model.

    Args:
        inbox (GeneratorQueue): Inbox to consume from.
        outbox (GeneratorQueue): Outbox to publish to.
        loop: Event loop the actor runs on; defaults to ``get_event_loop()``.
    """

    running = False
    _force_stop = False

    def __init__(self, inbox, outbox, loop=None):
        import sys

        self.inbox = inbox
        self.outbox = outbox

        if not loop:
            loop = get_event_loop()

        self._loop = loop
        # asyncio.Lock/Event dropped the ``loop`` argument in Python 3.10
        # (passing it raises TypeError). There they bind to the running loop
        # on first use; older interpreters still get the explicit loop.
        if sys.version_info >= (3, 10):
            primitive_kwargs = {}
        else:
            primitive_kwargs = {"loop": self._loop}
        # Held while the actor is paused.
        self._pause_lock = Lock(**primitive_kwargs)
        # Set by stop() to wake a pending _process() call.
        self._stop_event = Event(**primitive_kwargs)
        self._test = None
        self.__testy = None

        self.on_init()

    @property
    def paused(self):
        """Indicate if actor is paused."""
        return self._pause_lock.locked()

    async def start(self):
        """Main public entry point to start the actor."""
        await self.initialize()
        await self._start()

    async def initialize(self):
        """Initialize the actor before starting."""
        await self.on_start()

        if self._force_stop:
            return

        if not self.running:
            self.running = True

    async def _start(self):
        """Run the event loop and force the on_stop event."""
        try:
            await self._run()
        finally:
            await self.on_stop()

    async def resume(self):
        """Resume the actor (raises RuntimeError if it is not paused)."""
        await self.on_resume()
        self._pause_lock.release()

    async def pause(self):
        """Pause the actor; waits until resumed if it is already paused."""
        await self._pause_lock.acquire()
        await self.on_pause()

    async def _block_if_paused(self):
        """Block until the actor is resumed, if it is currently paused."""
        if self.paused:
            # Wait for resume()/stop() to release the lock, then release it
            # again immediately. Lock.release() is synchronous, so it must
            # not be awaited (that raised TypeError); ``async with`` does
            # the acquire/release pair correctly.
            async with self._pause_lock:
                pass

    async def _run(self):
        """Main event loop."""
        while self.running:
            await self._block_if_paused()

            await self._process()

    async def publish(self, data):
        """Push data to the outbox."""
        await self.outbox.put(data)

    async def stop(self):
        """Stop the actor and wake anything waiting on it."""
        self.inbox = None
        self.outbox = None
        self.running = False
        self._force_stop = True

        self._stop_event.set()

        # Unblock a paused actor; releasing an unlocked lock raises
        # RuntimeError, which just means it was not paused.
        try:
            self._pause_lock.release()
        except RuntimeError:
            pass

    async def _process(self):
        """Process incoming messages."""
        if not self.inbox:
            return

        # Race the next inbox item against stop(); get_first_completed
        # presumably returns whichever result arrives first.
        pending = {self.inbox.get(), self._stop_event.wait()}
        result = await get_first_completed(pending, self._loop)

        if self.running:
            await self.on_message(result)

    async def on_message(self, data):
        """Called when the actor receives a message."""
        raise NotImplementedError

    def on_init(self):
        """Called after the actor class is instantiated."""
        pass

    async def on_start(self):
        """Called before the actor starts ingesting the inbox."""
        pass

    async def on_stop(self):
        """Called after actor dies."""
        pass

    async def on_pause(self):
        """Called before the actor is paused."""
        pass

    async def on_resume(self):
        """Called before the actor is resumed."""
        pass
예제 #44
0
파일: recovery.py 프로젝트: smaxtec/faust
 def signal_recovery_start(self) -> Event:
     """Return the event signalling that recovery has started, creating it once."""
     event = self._signal_recovery_start
     if event is None:
         event = Event(loop=self.loop)
         self._signal_recovery_start = event
     return event
예제 #45
0
파일: recovery.py 프로젝트: smaxtec/faust
 def signal_recovery_end(self) -> Event:
     """Return the event signalling that recovery has ended, creating it once."""
     event = self._signal_recovery_end
     if event is None:
         event = Event(loop=self.loop)
         self._signal_recovery_end = event
     return event
예제 #46
0
파일: channel.py 프로젝트: tbug/aiochannel
class Channel(object):
    """A closable FIFO queue for asyncio code.

    A plain queue is "finished" once it is empty. A Channel is "finished"
    only once it has been closed *and* drained. After ``close()`` no new
    items are accepted, but items already buffered can still be retrieved.
    """

    def __init__(self, maxsize=0, *, loop=None):
        """Create a channel buffering at most ``maxsize`` items (0 = unbounded).

        :param maxsize: buffer capacity; must be an ``int`` >= 0.
        :param loop: event loop used for waiter futures; defaults to
            ``get_event_loop()``.
        :raises TypeError: if ``maxsize`` is not an integer >= 0.
        """
        self._loop = get_event_loop() if loop is None else loop

        if not isinstance(maxsize, int) or maxsize < 0:
            raise TypeError("maxsize must be an integer >= 0 (default is 0)")
        self._maxsize = maxsize

        # Futures of coroutines blocked in get() / put(), oldest first.
        self._getters = deque()
        self._putters = deque()

        # "finished" means channel is closed and drained
        self._finished = self._new_event()
        self._close = self._new_event()

        self._init()

    def _new_event(self):
        """Return an ``Event`` bound to this channel's loop where supported."""
        try:
            return Event(loop=self._loop)
        except TypeError:
            # Python 3.10 removed the ``loop`` argument; the event then binds
            # to the running loop the first time it is waited on.
            return Event()

    def _init(self):
        self._queue = deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)

    def _wakeup_next(self, waiters):
        # Wake up the next waiter (if any) that isn't cancelled.
        while waiters:
            waiter = waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)
                break

    def __repr__(self):
        return '<{} at {:#x} maxsize={!r} qsize={!r}>'.format(
            type(self).__name__, id(self), self._maxsize, self.qsize())

    def __str__(self):
        return '<{} maxsize={!r} qsize={!r}>'.format(
            type(self).__name__, self._maxsize, self.qsize())

    def qsize(self):
        """Number of items in the channel buffer."""
        return len(self._queue)

    @property
    def maxsize(self):
        """Number of items allowed in the channel buffer."""
        return self._maxsize

    def empty(self):
        """Return True if the channel is empty, False otherwise."""
        return not self._queue

    def full(self):
        """Return True if there are maxsize items in the channel.

        Note: if the Channel was initialized with maxsize=0 (the default),
        then full() is never True.
        """
        return 0 < self._maxsize <= self.qsize()

    async def put(self, item):
        """Put an item into the channel.

        If the channel is full, wait until a free slot is available before
        adding item. If the channel is closed, or gets closed while waiting,
        raise ChannelClosed. This method is a coroutine.
        """
        while self.full() and not self._close.is_set():
            putter = self._loop.create_future()
            self._putters.append(putter)
            try:
                await putter
            except BaseException:  # includes CancelledError
                putter.cancel()  # Just in case putter is not done yet.
                if not self.full() and not putter.cancelled():
                    # We were woken up by get_nowait(), but can't take
                    # the call.  Wake up the next in line.
                    self._wakeup_next(self._putters)
                raise
        return self.put_nowait(item)

    def put_nowait(self, item):
        """Put an item into the channel without blocking.

        Raise ChannelClosed if the channel is closed, otherwise ChannelFull
        if no free slot is immediately available.
        """
        # Closed is checked first so that put() on a closed (possibly full)
        # channel reports ChannelClosed, as documented.
        if self._close.is_set():
            raise ChannelClosed
        if self.full():
            raise ChannelFull
        self._put(item)
        self._wakeup_next(self._getters)

    async def get(self):
        """Remove and return an item from the channel.

        If the channel is empty, wait until an item is available. Raise
        ChannelClosed once the channel is closed and drained. This method is
        a coroutine.
        """
        while self.empty() and not self._close.is_set():
            getter = self._loop.create_future()
            self._getters.append(getter)
            try:
                await getter
            except BaseException:  # includes CancelledError
                getter.cancel()  # Just in case getter is not done yet.
                if not self.empty() and not getter.cancelled():
                    # We were woken up by put_nowait(), but can't take
                    # the call.  Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise
        return self.get_nowait()

    def get_nowait(self):
        """Remove and return an item from the channel without blocking.

        Raise ChannelClosed if the channel is empty and closed, or
        ChannelEmpty if it is merely empty.
        """
        if self.empty():
            if self._close.is_set():
                raise ChannelClosed
            else:
                raise ChannelEmpty
        item = self._get()
        if self.empty() and self._close.is_set():
            # if empty _after_ we retrieved an item AND marked for closing,
            # set the finished flag
            self._finished.set()
        self._wakeup_next(self._putters)
        return item

    async def join(self):
        """Block until the channel is closed and drained."""
        await self._finished.wait()

    def close(self):
        """Mark the channel as closed and wake every pending waiter.

        Blocked put() calls then raise ChannelClosed. Blocked get() calls
        either receive a still-buffered item or raise ChannelClosed once
        none is left. Calling close() more than once is harmless.
        """
        self._close.set()
        # Wake every live waiter and let it re-check the channel state in its
        # own loop. Futures that are already done (e.g. cancelled waiters that
        # stayed queued) are skipped by _wakeup_next, because resolving a
        # done future raises InvalidStateError. Waking all getters, not just
        # the ones that "can't get an item", avoids stranding getters whose
        # items were already claimed by waiters woken earlier.
        while self._putters:
            self._wakeup_next(self._putters)
        while self._getters:
            self._wakeup_next(self._getters)

        if self.empty():
            # already empty, mark as finished
            self._finished.set()

    def closed(self):
        """Returns True if the Channel is marked as closed"""
        return self._close.is_set()

    def __aiter__(self):  # pragma: no cover
        """Return the async iterator itself (must not be awaitable, 3.7+)."""
        return self

    async def __anext__(self):  # pragma: no cover
        """Return the next item, ending iteration once closed and drained."""
        try:
            return await self.get()
        except ChannelClosed:
            raise StopAsyncIteration from None

    def __iter__(self):
        return iter(self._queue)