Example #1
class Crawler:
    def __init__(self, root_url, max_redirect):
        self.max_tasks = 10
        self.max_redirect = max_redirect
        self.q = Queue()
        self.seen_urls = set()

        # aiohttp's ClientSession does connection pooling and
        # HTTP keep-alives for us.
        self.session = aiohttp.ClientSession(loop=loop)

        # Put (URL, max_redirect) in the queue.
        self.q.put_nowait((root_url, self.max_redirect))

    @asyncio.coroutine
    def crawl(self):
        """Run the crawler until all work is done."""
        workers = [asyncio.Task(self.work())
                   for _ in range(self.max_tasks)]

        # When all work is done, exit.
        yield from self.q.join()
        for w in workers:
            w.cancel()

    @asyncio.coroutine
    def work(self):
        while True:
            url, max_redirect = yield from self.q.get()

            # Download page and add new links to self.q.
            yield from self.fetch(url, max_redirect)
            # Call task_done() only after the new links have been added to q.
            self.q.task_done()
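A minimal driver for this crawler might look like the sketch below; the module-level loop referenced by the snippet and the root URL are assumptions:

loop = asyncio.get_event_loop()
crawler = Crawler('http://example.com/', max_redirect=10)
loop.run_until_complete(crawler.crawl())
loop.close()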
Example #2
class Crawler:
    def __init__(self, root_url, max_redirect):
        self.max_tasks = 10
        self.max_redirect = max_redirect
        self.q = Queue()
        self.seen_urls = set()

        # aiohttp's ClientSession does connection pooling and
        # HTTP keep-alives for us.
        self.session = aiohttp.ClientSession(loop=loop)

        # Put (URL, max_redirect) in the Queue
        self.q.put_nowait((root_url, self.max_redirect))

    @asyncio.coroutine
    def crawl(self):
        '''Run the crawler until all work is done.'''
        workers = [asyncio.Task(self.work()) for _ in range(self.max_tasks)]

        # When all work is done, exit.
        yield from self.q.join()
        for w in workers:
            w.cancel()

    @asyncio.coroutine
    def work(self):
        while True:
            url, max_redirect = yield from self.q.get()

            # Download page and add new links to self.q
            yield from self.fetch(url, max_redirect)
            self.q.task_done()

    @asyncio.coroutine
    def fetch(self, url, max_redirect):
        # Handle redirects ourselves.
        response = yield from self.session.get(url, allow_redirects=False)

        try:
            if is_redirect(response):
                if max_redirect > 0:
                    next_url = response.headers['location']
                    if next_url in self.seen_urls:
                        # We have done this before.
                        return

                    # Remember we have seen this url.
                    self.seen_urls.add(next_url)

                    # Follow the redirect. One less redirect remains.
                    self.q.put_nowait((next_url, max_redirect - 1))
            else:
                links = yield from self.parse_links(response)
                # Python set-logic:
                for link in links.difference(self.seen_urls):
                    self.q.put_nowait((link, self.max_redirect))
                self.seen_urls.update(links)
        finally:
            # Return connection to pool.
            yield from response.release()
Example #3
class ModelThreadPool:
    def __init__(self):
        self.bs_in_queue = Queue()
        self.bs_out_queue = Queue()

        thread = threading.Thread(target=self.bubble_sorter, args=())
        thread.daemon = True
        thread.start()

        # Repeat the process for the rest of the threads

    def input_converter(self):
        pass

    def bubble_sorter(self):
        fn = lambda x: 1

        while True:
            num_wires = self.bs_in_queue.get()  # blocks until a request arrives
            self.bs_out_queue.put(fn(num_wires))

    def query_bubble_sorter(self, num_wires):
        self.bs_in_queue.put(num_wires)
        sleep(2)
        return self.bs_out_queue.get()
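A rough usage sketch, assuming the worker thread started in __init__ is running (the argument value is arbitrary and fn above is only a placeholder):

pool = ModelThreadPool()
result = pool.query_bubble_sorter(8)  # blocks about two seconds, then reads the worker's reply
print(result)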
Example #4
async def _merge_helper(q: Queue, end: Task, ait: AsyncIterable[Any]) -> None:
    ch = aiter(ait)

    while True:
        pending_take = create_task(cast(Any, ch.__anext__()))
        done_1, _ = await wait((end, pending_take),
                               return_when=FIRST_COMPLETED)

        if pending_take in done_1:
            try:
                item = await pending_take
            except StopAsyncIteration:
                break
            else:
                if end in done_1:
                    break
                else:
                    pending_put = create_task(q.put(item))
                    done_2, _ = await wait((end, pending_put),
                                           return_when=FIRST_COMPLETED)

                    if pending_put in done_2:
                        await pending_put

                    if end in done_2:
                        break

        if end in done_1:
            break
Example #5
async def produce(conf: DotDict, queue: asyncio.Queue,
                  conn_queue: asyncio.Queue, logger) -> None:
    """
    Reads messages from the queue, builds rows, and inserts them into the PG table.

    Gracefully cancels async tasks and shuts down the event loop in case of
    connectivity issues with PostgreSQL. It is assumed that service failures
    should be handled by container orchestration software (it is easy in
    Kubernetes).

    """
    # If return_exceptions=True, exceptions are treated the same as successful
    # results, and aggregated in the result list
    conn_or_exc = await asyncio.gather(connect_pg(conf, logger),
                                       return_exceptions=True)
    _handle_exc(conn_or_exc, conn_queue, logger)

    # Will execute if there were no exceptions at establishing connection
    conn = conn_or_exc[0]
    # Using a dedicated queue to share PG connection with `shutdown`
    # to properly clean up
    asyncio.create_task(conn_queue.put(conn))

    # Create table
    await conn.execute(_create_table_query(conf.pg_table_name))  # type: ignore

    while True:
        msg_bytes = await queue.get()
        msg = json.loads(msg_bytes)
        logger.info("writing metric to PostgreSQL")
        query = _compose_insert_query(conf.pg_table_name, msg)
        await conn.execute(query)  # type: ignore
        queue.task_done()
Example #6
def search_async(q_list):
    search_queue = Queue()
    search = build_search(is_async=True)
    show_msg = False

    # loop through companies
    for q in q_list:
        search.params_dict["q"] = q
        data = search.get_dict()

        # add search to the search_queue
        search_queue.put(data)

        if show_msg:
            print("execute async search: q = " + q)
            print("add search to the queue where id: " +
                  data['search_metadata']['id'])
    print("wait until all search statuses are cached or success")
    # Create regular search
    search = GoogleSearch({"async": True})
    while not search_queue.empty():
        data = search_queue.get()
        search_id = data['search_metadata']['id']

        # retrieve search from the archive - blocker
        search_archived = search.get_search_archive(search_id)
        if show_msg:
            print(search_id + ": get search from archive")
            print(search_id + ": status = " +
                  search_archived['search_metadata']['status'])

        # check status
        if re.search('Cached|Success',
                     search_archived['search_metadata']['status']):
            if show_msg:
                print(search_id + ": search done with q = " +
                      search_archived['search_parameters']['q'])
            QUERY_RESULT[search_archived['search_parameters']['q']
                         [-5:]] = search_archived["organic_results"]
        else:
            # requeue search_queue
            print(search_id + ": requeue search")
            search_queue.put(data)
            # wait 1s
            time.sleep(1)
    # search is over.
    print('all searches completed')
Example #7
def test_nonmatching():
    i_queue = Queue()
    o_queue = find_events(i_queue)

    for in_string in NONMATCHING_TESTS:
        yield from i_queue.put(in_string)

    yield from o_queue.get()
Example #8
def test_matching():
    i_queue = Queue()
    o_queue = find_events(i_queue)

    for in_string, event in MATCHING_TESTS:
        yield from i_queue.put(in_string)
        ev = yield from o_queue.get()
        eq_(ev, event)
Example #9
class AverageMessageHandlerForTest(AverageMemoryMessageHandler):
    def __init__(self, keys, average_period_minutes=0):
        super().__init__(keys, average_period_minutes)
        self.queue = Queue()

    @asyncio.coroutine
    def save(self, average_message):
        yield from self.queue.put(average_message)
Example #10
class LogData:
    """ Handle incoming pkts on the ../data topic """
    def __init__(self, conf):
        self.conf = conf
        self.logfile = open(conf["file"], "a+")
        self.pkt_inq = Queue()
        self.nodes_online = 0

    def log_state(self, slid, new_state):
        logger.debug("logdata node %d %s", slid, new_state)
        self.nodes_online += 1 if new_state == "ONLINE" else -1

    def post_incoming(self, pkt):
        """ a pkt arrives """

        self.log_pkt(pkt, "receive")
        self.pkt_inq.put(pkt)

    def post_outgoing(self, pkt):
        """ a pkt is sent """
        self.log_pkt(pkt, "send")

    def log_pkt(self, pkt, direction):
        self.logfile.write(pkt.json() + "\n")
        self.logfile.flush()

    def wait_pkt_number(self, pktnumber, timeout, num_packets):
        """ wait for pkt with number pktnumber for a max of timeout seconds """
        lwait = timeout
        packets_seen = 0
        while True:
            now = time.time()
            try:
                test_pkt = self.pkt_inq.get(block=True, timeout=lwait)
                packets_seen += 1
            except Empty:
                test_pkt = None
            logger.debug("wait_pkt_number pkt %s", test_pkt)
            waited = time.time() - now
            if test_pkt and test_pkt.pkt[
                    'pktno'] == pktnumber and packets_seen == num_packets:
                return pktnumber
            if waited >= lwait or (test_pkt and test_pkt.pkt[
                    'pktno'] > pktnumber):  # our pkt will never arrive
                return None
            lwait -= waited
Example #11
class AbstractQueue(AbstractDataset, ABC):
    def __init__(self, maxsize=3, *args):
        super().__init__(*args)

        # initialise queue
        self.queue = Queue(maxsize=maxsize)
        return

    # ********************************************** #

    def __iter__(self):
        # self.background_worker = Thread(target=self.enqueue(), daemon=True)
        # self.background_worker.start()
        return self

    # ********************************************** #

    def __next__(self):
        if not self.queue.empty():
            batch = self.queue.get()
            self.queue.task_done()
            return batch
        else:
            raise StopIteration

    async def request_sample(self, idx):
        start = time.time()
        print("Fetching new sample...", end="\r")
        sample = await self.samples[idx]
        print("Fetched in %4f" % (time.time() - start))
        return sample

    def _producer(self):
        # this is async
        for frame in self.frames:
            t = Thread(target=self.request_sample, args=(frame,))
            self.queue.put(t)
        self.queue.join()
        return

    def print_random(self):
        idx = rn.randint(0, len(self))
        sample = self.request_sample(idx)
        print(f"Example batch {idx:s}")
        print(sample)
Example #12
class MessageHandler(ws.WS):
    def __init__(self):
        self.queue = Queue()

    def get(self):
        return self.queue.get()

    def on_message(self, websocket, message):
        return self.queue.put(message)
Example #13
async def run():
    q = Queue()
    await asyncio.wait([q.put(i) for i in range(10)])
    tasks = [asyncio.ensure_future(work(q))]
    print('wait join')
    await q.join()
    print('end join')
    for task in tasks:
        task.cancel()
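The work(q) coroutine used above is not shown; a minimal sketch that fits this pattern (consume items and call task_done() so q.join() can return) could be:

async def work(q):
    while True:
        item = await q.get()
        print('processing', item)
        q.task_done()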
Example #14
def _handle_exc(
    results: Tuple[Union[Any, BaseException], ...],
    conn_queue: asyncio.Queue,
    logger,
):
    for res in results:
        if isinstance(res, Exception):
            logger.error("failed to connect to PostgreSQL", error=res)
            asyncio.create_task(conn_queue.put(None))
            raise res  # trigger global exception handler
Example #15
def new_queue():
    global _main_loop

    queue = Queue(loop=_main_loop)

    def putter(item):
        _main_loop.call_soon_threadsafe(queue.put_nowait, item)

    queue.put = putter
    return queue
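A hedged sketch of how such a queue might be used: a worker thread calls the patched put (which hops onto the loop thread via call_soon_threadsafe) while a coroutine consumes items on _main_loop. The helper names are assumptions:

import threading

def feed(queue):
    for i in range(3):
        queue.put(i)  # safe from a worker thread: schedules put_nowait on the loop

async def drain(queue):
    for _ in range(3):
        print(await queue.get())

# assuming _main_loop has already been set to the event loop:
q = new_queue()
threading.Thread(target=feed, args=(q,), daemon=True).start()
_main_loop.run_until_complete(drain(q))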
Example #16
async def third():
    await first()
    myqueue = Queue()
    q = Queue()
    await asyncio.wait([q.put(i) for i in range(10)])
    tasks = [asyncio.ensure_future(work(myqueue))]
    await myqueue.join()
    print(3)
    for task in tasks:
        task.cancel()
Example #17
def new_queue():
    global _main_loop

    queue = Queue(loop=_main_loop)

    def putter(item):
        _main_loop.call_soon_threadsafe(queue.put_nowait, item)

    queue.put = putter
    return queue
Example #18
async def run():
    q = Queue()
    await asyncio.wait([q.put(i) for i in range(10)])    # Coroutine batch 1: Queue.put is a coroutine, so the numbers are not necessarily written to the queue in order
    print("[q]",q)
    tasks = [asyncio.ensure_future(work(q))]    # Coroutine 2: create a work coroutine to process the Queue
    print("[tasks]", tasks)
    print('wait join')
    await q.join()    # Wait for the queue to drain; without this, task.cancel() would run immediately and no work would be done.
    print('end join')
    for task in tasks:
        task.cancel()
Example #19
async def connect(q: asyncio.Queue):
    try:
        async with websockets.connect("ws://127.0.0.1:8000/ws") as ws:
            payload = {
                "command": "start",
                "symbol": "btcusdt",
            }
            await ws.send(json.dumps(payload))
            data = json.loads(await ws.recv())
            await q.put(data)
            payload = {
                "command": "stop",
                "symbol": "btcusdt",
            }
            print(data)
            assert data is not None, "no data"
            assert data.get(
                "symbol") is not None, "Received data wrong format"
    except KeyboardInterrupt:
        print("Canceled")
Example #20
async def listen(task_id: str, q: asyncio.Queue) -> None:
    """Listen kafka

    :param task_id: task_id
    :type task_id: str
    :param q: queue for consumer
    :type q: asyncio.Queue
    """
    consumer = AIOKafkaConsumer(
        'data', 'finished',
        loop=asyncio.get_event_loop(),
        bootstrap_servers='10.199.13.36:9091, 10.199.13.37:9092, 10.199.13.38:9093',
        value_deserializer=lambda m: json.loads(m.decode('utf8')))
    await consumer.start()
    async for msg in consumer:
        if msg.topic == 'data' and msg.value.get('task_id') == task_id and msg.value.get('data'):
            item = preprocess(msg.value.get('data'), {})
            await q.put(item['data'])
        elif msg.topic == 'finished' and msg.value.get('task_id') == task_id:
            break
    await consumer.stop()
Example #21
async def run_more(file_list, file_workers, block_size, one_workers):
    """同时下载下载多个文件"""
    files_queues = Queue()
    await asyncio.wait([files_queues.put(i) for i in file_list])
    tasks = [
        asyncio.ensure_future(
            worker_more(files_queues, block_size, one_workers))
        for i in range(file_workers)
    ]
    await files_queues.join()
    for task in tasks:
        task.cancel()
Example #22
    def bfs(self, vertex, visited, processing, collection):
        '''breadth-first search
        vertex: current vertex to be visited
        visited: record of visited vertices
        processing: a function to process a vertex and return a value
        collection: a list of the values returned by processing
        '''
        q = Queue()
        visited[vertex.id] = 1
        collection.append(processing(vertex))
        q.put(vertex)
        while (not q.empty()):
            v0 = q.get()
            for e in v0.edges:
                v1 = e.vertices[1] if e.vertices[0] == v0 \
                 else e.vertices[0]

                if visited[v1.id] == 0:
                    visited[v1.id] = 1
                    collection.append(processing(v1))
                    q.put(v1)
Example #23
def breadthfirst(bt):
    """breadthfirst: binary tree -> list[Node]
    Purpose: Runs a breadth first search on a binary tree
    Consumes: a binary tree object
    Produces: a list of Nodes in breadth first search order
    Example: 
                    A 
    breadthfirst(  / \  ) -> [A B C]
                  B   C 
    If tree is empty, should return an empty list. If the tree
    is null, you should throw InvalidInputException. 
    """
    if bt is None:
        raise InvalidInputException("Input is None")
    if bt.isEmpty():
        return []

    Q = Queue()
    qlist = []
    qlist.append(bt.root())
    Q.put(bt.root())

    while not Q.empty():

        node = Q.get()

        if bt.hasLeft(node):
            Q.put(bt.left(node))
            qlist.append(bt.left(node))
        if bt.hasRight(node):
            Q.put(bt.right(node))
            qlist.append(bt.right(node))

    return qlist
Example #24
class Crawler:
    def __init__(self, root_url):
        self.max_tasks = 10
        self.q = Queue()
        self.seen_urls = set()

        self.q.put_nowait(root_url)

    async def crawl(self):
        self.session = aiohttp.ClientSession(loop=loop)
        workers = [asyncio.Task(self.work()) for _ in range(self.max_tasks)]

        await self.q.join()
        for w in workers:
            w.cancel()
        self.session.close()

    async def work(self):
        while True:
            url = await self.q.get()
            await self.fetch(url)
            self.q.task_done()

    async def fetch(self, url):
        response = await self.session.get(url)
        links = await self.parse_links(response)
        for link in links.difference(self.seen_urls):
            self.q.put_nowait(link)
        self.seen_urls.update(links)
        print(response)

    async def parse_links(self, response):
        soup = BeautifulSoup(response, 'html.parser')
        anchors = soup.find_all('a')
        links = []
        for anchor in anchors:
            if anchor.get('href'):
                links.append(anchor['href'])
        return links
Example #25
class L2TTY:
    def __init__(self, local_call: str, remote_call: str,
                 datalink: DataLinkManager):
        self.local_call = AX25Call.parse(local_call)
        self.remote_call = AX25Call.parse(remote_call)
        self.dl = datalink
        self.stdin_queue = Queue()
        self.connected = False
        self.circuit_id = None
        EventBus.bind(
            EventListener(f"link.{local_call}.connect",
                          f"link_{local_call}_connect", self.handle_connect))
        EventBus.bind(
            EventListener(f"link.{local_call}.disconnect",
                          f"link_{local_call}_disconnect",
                          self.handle_disconnect))
        EventBus.bind(
            EventListener(f"link.{local_call}.inbound",
                          f"link_{local_call}_inbound", self.handle_data))

    async def start(self):
        while True:
            next_stdin = await self.stdin_queue.get()
            if next_stdin is not None:
                if not self.connected:
                    print("connecting...")
                    self.dl.dl_connect_request(self.remote_call)
                else:
                    await self.dl.dl_data_request(self.remote_call,
                                                  L3Protocol.NoLayer3,
                                                  next_stdin.encode("utf-8"))
                self.stdin_queue.task_done()

    def handle_connect(self, remote_call: AX25Call):
        print(f"Connected to {remote_call} L2")
        sys.stdout.write("> ")
        self.connected = True

    def handle_disconnect(self, remote_call: AX25Call):
        print(f"Disconnected from {remote_call} L2")
        self.connected = False
        graceful_shutdown()

    def handle_stdin(self):
        line = sys.stdin.readline().strip()
        asyncio.create_task(self.stdin_queue.put(line))

    def handle_data(self, remote_call: AX25Call, protocol: L3Protocol,
                    data: bytes):
        msg = str(data, 'utf-8')
        sys.stdout.write(msg)
Example #26
class ConcurrentQueue:
    def __init__(self, max_size=10):
        self.max_size = max_size  # maximum capacity
        self.lock = Lock()  # mutex
        self.cond = Condition(self.lock)
        self.q = Queue()  # underlying data store (the "warehouse")

# Get data (consume)

    def get(self):  # take an item from the queue
        # Acquire the mutex and condition variable (the condition wraps the mutex)
        if self.cond.acquire():
            # While the store is empty, the consumer has to wait
            while self.q.empty():
                print('store is empty, please wait...')
                self.cond.wait()  # wait on the condition variable

            # The store is not empty
            obj = self.q.get()  # take the item to consume
            self.cond.notify()  # notify waiting producer threads that an item was consumed
            self.cond.release()  # release the lock

        return obj

# Put data (produce)

    def put(self, obj):
        if self.cond.acquire():
            # While the store is full, nothing more can be produced
            while self.q.qsize() >= self.max_size:
                print('store is full, please wait...')
                self.cond.wait()  # wait on the condition variable

            # The store has room, so produce
            self.q.put(obj)

            self.cond.notify()  # notify waiting consumer threads that an item was produced
            self.cond.release()  # release the lock
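A brief producer/consumer sketch with this class (the thread setup and item counts are assumptions):

import threading
import time

cq = ConcurrentQueue(max_size=3)

def producer():
    for i in range(5):
        cq.put(i)
        time.sleep(0.1)

def consumer():
    for _ in range(5):
        print('got', cq.get())

threading.Thread(target=producer).start()
threading.Thread(target=consumer).start()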
Example #27
async def read_messages_task(reader: asyncio.StreamReader,
                             received_msg_queue: asyncio.Queue,
                             download_stats: PeerConnectionStats = None):
    """
    Coroutine intended to be scheduled as a task that will continually
    read messages from the peer and populate them into the given queue.

    :param reader: `StreamReader` to read from.
    :param received_msg_queue: `Queue` to place messages into.
    :param download_stats: Optional `PeerConnectionStatus` object to populate stats into.
    """
    while True:
        received = await _receive_from_peer(reader, download_stats)
        if received:
            asyncio.create_task(received_msg_queue.put(received))
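This coroutine is meant to run in the background; a hedged sketch of scheduling and later cancelling it (the reader, queue, and stats objects are assumptions):

reader_task = asyncio.create_task(
    read_messages_task(reader, received_msg_queue, download_stats))
# ... later, when the connection closes:
reader_task.cancel()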
Example #28
async def run_one(url, size, one_workers, block_size, temp_folder,
                  is_continue):
    """分配创建任务,将一个文件分块下载"""
    block_queues = Queue()
    await asyncio.wait([
        block_queues.put(i)
        for i in make_block(size, block_size, temp_folder, is_continue)
    ])
    tasks = [
        asyncio.ensure_future(worker_one(block_queues, url, temp_folder))
        for i in range(one_workers)
    ]
    await block_queues.join()
    for task in tasks:
        task.cancel()
Example #29
def bfs(self):
    queue = Queue()
    queue.put(self)

    while not queue.empty():
        current_node = queue.get()
        print(current_node.value)

        if current_node.left_child:
            queue.put(current_node.left_child)

        if current_node.right_child:
            queue.put(current_node.right_child)
Example #30
async def page_monitor(client: AsyncClient, conf: DotDict,
                       queue: asyncio.Queue, logger) -> None:
    """
    Collect webpage availability metrics.

    Metrics are placed into the queue for subsequent processing.
    In case of connectivity failures to the Kafka broker, it retries until
    the connection is available again.

    Network connectivity issues are mitigated by exponential backoff
    with a configurable number of retries.

    """
    # Configure exponential backoff without jitter;
    # no competing clients, as described here:
    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    backoff_deco = backoff.on_exception(
        backoff.expo,
        httpx.TransportError,
        on_backoff=_backoff_handler,
        max_tries=conf.backoff_retries,
        jitter=None,
    )

    while True:
        # Ping webpage
        resp = await backoff_deco(client.get)(conf.page_url)
        http_code = resp.status_code
        resp_time = resp.elapsed

        # Compose Kafka message
        msg = {
            # Following xkcd.com/1179, sorry ISO 8601
            "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "page_url": conf.page_url,
            "http_code": http_code,
            "response_time": resp_time.microseconds,
        }
        logger.info(source="monitor", message=msg)

        asyncio.create_task(queue.put(msg))
        await asyncio.sleep(conf.ping_interval)
Example #31
class Test:
    def __init__(self):
        self.que = Queue()
        self.pue = Queue()

    async def consumer(self):
        while True:
            try:
                print('consumer', await self.que.get())
            finally:
                try:
                    self.que.task_done()
                except ValueError:
                    if self.que.empty():
                        print("que empty")

    async def work(self):
        while True:
            try:
                value = await self.pue.get()
                print('producer', value)
                await self.que.put(value)
            finally:
                try:
                    self.pue.task_done()
                except ValueError:
                    if self.pue.empty():
                        print("pue empty")

    async def run(self):
        tasks = [asyncio.ensure_future(self.work()),
                 asyncio.ensure_future(self.consumer())]

        await asyncio.wait([self.pue.put(i) for i in range(10)])

        print('p queue join')
        await self.pue.join()
        print('p queue is done & q queue join')
        await self.que.join()
        print('q queue is done')

        asyncio.gather(*tasks).cancel()
Example #32
    def fill_peer_request_queue(self, peer: PeerInfo,
                                msg_queue: asyncio.Queue) -> bool:
        """
        Fills the given queue with up to 10 new requests for the peer, returning
        True if more requests were added or False otherwise.

        :param peer: The peer asking for a top up
        :param msg_queue: the message queue to place the requests into
        :return: True if more requests were added or the peer has any outstanding.
        """
        added_more = False
        num_needed = 10 - len(self._peer_unfulfilled_requests[peer])

        for _ in range(num_needed):
            request = self.next_request_for_peer(peer)
            if not request:  # no more requests for this peer
                break
            asyncio.create_task(msg_queue.put(request))
            added_more = True
        return added_more
Example #33
class Test:
    def __init__(self):
        self.que = Queue()
        self.pue = Queue()

    async def consumer(self):
        while True:
            try:
                print('consumer', await self.que.get())
            finally:
                try:
                    self.que.task_done()  # tell the queue one item is done; otherwise join() cannot tell when the queue is drained
                except ValueError:  # task_done() on an already-drained queue raises ValueError('task_done() called too many times'); catch it and check whether the queue is empty
                    if self.que.empty():
                        print("que empty")

    async def work(self):
        while True:
            try:
                value = await self.pue.get()
                print('producer', value)
                await self.que.put(value)
            finally:
                try:
                    self.pue.task_done()  # tell the queue one item is done; otherwise join() cannot tell when the queue is drained
                except ValueError:  # task_done() on an already-drained queue raises ValueError('task_done() called too many times'); catch it and check whether the queue is empty
                    if self.pue.empty():
                        print("pue empty")

    async def run(self):
        await asyncio.wait([self.pue.put(i) for i in range(10)])
        tasks = [asyncio.ensure_future(self.work())]
        tasks.append(asyncio.ensure_future(self.consumer()))
        print('p queue join')
        await self.pue.join()
        print('p queue is done & q queue join')
        await self.que.join()
        print('q queue is done')
        for task in tasks:
            task.cancel()
Example #34
class MQTTClientProtocol(FlowControlMixin, asyncio.Protocol):

    def __init__(self, loop, config):
        super().__init__(loop=loop)
        self._loop = loop
        self._config = config

        self._transport = None
        self._write_pending_data_topic = []     # tuple (data, topic)
        self._connected = False

        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])

        self._peername = None

        self._reader_task = None
        self._data_task = None
        self._keepalive_task = None
        self._keepalive_timeout = self._config['timeout']
        self._reader_ready = None
        self._reader_stopped = asyncio.Event(loop=self._loop)
        self._stream_reader = StreamReader(loop=self._loop)
        self._stream_writer = None
        self._reader = None

        self._topic_to_clients = {}

        self._queue = Queue(loop=loop)

    async def create_connection(self):
        try:
            # TODO handle pending task
            transport, protocol = await self._loop.create_connection(lambda: self, self._config['address'], self._config['port'])
        except OSError as e:
            logging.error("{0} when connecting to mqtt server({1}:{2})".format(e, self._config['address'], self._config['port']))
            logging.error("Reconnection will be performed after 5s...")
            await asyncio.sleep(5)     # TODO:retry interval
            self._loop.create_task(self.create_connection())

    def connection_made(self, transport):
        self._peername = transport.get_extra_info('peername')
        self._transport = transport

        self._stream_reader.set_transport(transport)
        self._reader = StreamReaderAdapter(self._stream_reader)
        self._stream_writer = StreamWriter(transport, self,
                                           self._stream_reader,
                                           self._loop)
        self._loop.create_task(self.start())

    def connection_lost(self, exc):
        logging.info("Lost connection with mqtt server{0}".format(self._peername))
        super().connection_lost(exc)
        self._topic_to_clients = {}

        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)

        self.stop()

        self.reestablish_connection()

    def reestablish_connection(self):
        self._stream_reader = StreamReader(loop=self._loop)
        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])
        self._loop.call_later(5, lambda: self._loop.create_task(self.create_connection()))

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()

    @asyncio.coroutine
    def consume(self):
        while self._transport is not None:
            packet = yield from self._queue.get()
            if packet is None:
                break

            if self._transport is None:
                break
            yield from self._send_packet(packet)

    @asyncio.coroutine
    def start(self):
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
        yield from self._reader_ready.wait()
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

        self._data_task = self._loop.create_task(self.consume())

        # send connect packet
        connect_vh = ConnectVariableHeader(keep_alive=self._keepalive_timeout)
        connect_vh.password_flag = True
        password = self._encryptor.encrypt(self._encryptor.password.encode('utf-8'))
        connect_payload = ConnectPayload(client_id=ConnectPayload.gen_client_id(), password=password)
        connect_packet = ConnectPacket(vh=connect_vh, payload=connect_payload)
        yield from self._do_write(connect_packet)

        logging.info("Creating connection to mqtt server.")

    @asyncio.coroutine
    def stop(self):
        self._connected = False
        if self._keepalive_task:
            self._keepalive_task.cancel()
        self._data_task.cancel()
        logging.debug("waiting for tasks to be stopped")
        if not self._reader_task.done():

            if not self._reader_stopped.is_set():
                self._reader_task.cancel()  # this will cause the reader_loop handle CancelledError
                # yield from asyncio.wait(
                #     [self._reader_stopped.wait()], loop=self._loop)
            else:   # caused by reader_loop break statement
                if self._transport:
                    self._transport.close()
                    self._transport = None

    @asyncio.coroutine
    def _reader_loop(self):
        running_tasks = collections.deque()
        while True:
            try:
                self._reader_ready.set()
                while running_tasks and running_tasks[0].done():
                    running_tasks.popleft()
                if len(running_tasks) > 1:
                    logging.debug("{} Handler running tasks: {}".format(self._peername, len(running_tasks)))

                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self._reader),
                    self._keepalive_timeout + 10, loop=self._loop)
                if fixed_header:
                    if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
                        logging.warning("{} Received reserved packet, which is forbidden: closing connection".format(self._peername))
                        break
                    else:
                        cls = packet_class(fixed_header)
                        packet = yield from cls.from_stream(self._reader, fixed_header=fixed_header)
                        task = None
                        if packet.fixed_header.packet_type == CONNACK:
                            task = ensure_future(self.handle_connack(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGREQ:
                            task = ensure_future(self.handle_pingreq(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGRESP:
                            task = ensure_future(self.handle_pingresp(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PUBLISH:
                            # task = ensure_future(self.handle_publish(packet), loop=self._loop)
                            self.handle_publish(packet)
                        # elif packet.fixed_header.packet_type == SUBSCRIBE:
                        #     task = ensure_future(self.handle_subscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBSCRIBE:
                        #     task = ensure_future(self.handle_unsubscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == SUBACK:
                        #     task = ensure_future(self.handle_suback(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBACK:
                        #     task = ensure_future(self.handle_unsuback(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == DISCONNECT:
                            task = ensure_future(self.handle_disconnect(packet), loop=self._loop)
                        else:
                            logging.warning("{} Unhandled packet type: {}".format(self._peername, packet.fixed_header.packet_type))
                        if task:
                            running_tasks.append(task)
                else:
                    logging.debug("{} No more data (EOF received), stopping reader coro".format(self._peername))
                    break
            except MQTTException:
                logging.debug("{} Message discarded".format(self._peername))
            except asyncio.CancelledError:
                # logger.debug("Task cancelled, reader loop ending")
                break
            except asyncio.TimeoutError:
                logging.debug("{} Input stream read timeout".format(self._peername))
                break
            except NoDataException:
                logging.debug("{} No data available".format(self._peername))
            except BaseException as e:
                logging.warning(
                    "{}:{} Unhandled exception in reader coro: {}".format(type(self).__name__, self._peername, e))
                break
        while running_tasks:
            running_tasks.popleft().cancel()
        self._reader_stopped.set()
        logging.debug("{} Reader coro stopped".format(self._peername))
        yield from self.stop()

    def write(self, data: bytes, topic):
        if not self._connected:
            self._write_pending_data_topic.append((data, topic))
            if len(self._write_pending_data_topic) > 50:
                self._write_pending_data_topic.clear()
        else:
            data = self._encryptor.encrypt(data)
            packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
            ensure_future(self._do_write(packet), loop=self._loop)

    def write_eof(self, topic):
        packet = PublishPacket.build(topic, b'', None, dup_flag=0, qos=0, retain=1)
        ensure_future(self._do_write(packet), loop=self._loop)

    @asyncio.coroutine
    def _do_write(self, packet):
        yield from self._queue.put(packet)

    @asyncio.coroutine
    def _send_packet(self, packet):
        try:
            yield from packet.to_stream(self._stream_writer)
        except ConnectionResetError:
            return

        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    def handle_write_timeout(self):
        packet = PingReqPacket()
        # TODO: check transport
        self._transport.write(packet.to_bytes())
        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    def handle_read_timeout(self):
        self._loop.create_task(self.stop())

    @asyncio.coroutine
    def handle_connack(self, connack: ConnackPacket):
        if connack.variable_header.return_code == 0:
            self._connected = True
            logging.info("Connection to mqtt server established!")

            if len(self._write_pending_data_topic) > 0:
                self._keepalive_task.cancel()
                for data, topic in self._write_pending_data_topic:
                    data = self._encryptor.encrypt(data)
                    packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
                    yield from self._do_write(packet)
                self._write_pending_data_topic = []
                self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)
        else:
            logging.info("Unable to create connection to mqtt server! Shuting down...")
            self._loop.create_task(self.stop())

    # @asyncio.coroutine
    def handle_publish(self, publish_packet: PublishPacket):
        data = bytes(publish_packet.data)

        server = self._topic_to_clients.get(publish_packet.topic_name, None)
        if server is None:
            logging.info("Received unregistered publish topic({0}) from mqtt server, packet will be ignored.".format(
                publish_packet.topic_name))
        if not publish_packet.retain_flag:    # retain=1 indicates we should close the client connection
            data = self._encryptor.decrypt(data)
            if server is not None:
                server.write(data)
        else:
            if server is not None:
                server.close(force=True)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        logging.info("Received PingRespPacket from mqtt server.")

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        logging.info("Received PingReqPacket from mqtt server, Replying PingResqPacket.")
        ping_resp = PingRespPacket()
        yield from self._do_write(ping_resp)

    def register_client_topic(self, topic, server):
        self._topic_to_clients[topic] = server

    def unregister_client_topic(self, topic):
        self._topic_to_clients.pop(topic, None)
Example #35
class Messagedispatcher:
    def __init__(self, communicator):
        self.communicator = communicator
        self.messages = {
            "direct": {
                "status": {
                    "class": messages.StatusDirect,
                    "queue": Queue()
                },
                "pinor": {
                    "class": messages.PinorDirect,
                    "queue": Queue()
                }
            },
            "mesh": {
                "status": {
                    "class": messages.StatusMesh,
                    "queue": Queue()
                },
                "pinor": {
                    "class": messages.PinorMesh,
                    "queue": Queue()
                },
                "return": {
                    "class": messages.ReturnMesh,
                    "queue": Queue()
                },
                "deploy": {
                    "class": messages.DeployMesh,
                    "queue": Queue()
                },
                "grid": {
                    "class": messages.GridMesh,
                    "queue": Queue()
                }
            }
        }
        self.mesh_queue = Queue()
    @coroutine
    def wait_for_message(self, *types):
        x = self.messages
        for i in types:
            x = x[i]
        q = x["queue"]
        return (yield from q.get())
    @coroutine
    def get_mesh_message(self):
        return (yield from self.mesh_queue.get())
    @coroutine
    def startup(self):
        while True:
            meshput = False
            msg = yield from self.communicator.receive()
            if msg["type"] == "mesh":
                meshput = True
            x = self.messages
            x = x[msg["type"]]
            x = x[msg["data"]["datatype"]]
            q = x["queue"]
            c = x["class"]
            emsg = c.from_json(msg)
            yield from q.put(emsg)
            if meshput:
                # print("RECEIVE:  " + str(msg) + "\n")
                yield from self.mesh_queue.put(emsg)
Example #36
class BrokerProtocolHandler(ProtocolHandler):
    def __init__(self, plugins_manager: PluginManager, session: Session=None, loop=None):
        super().__init__(plugins_manager, session, loop)
        self._disconnect_waiter = None
        self._pending_subscriptions = Queue(loop=self._loop)
        self._pending_unsubscriptions = Queue(loop=self._loop)

    @asyncio.coroutine
    def start(self):
        yield from super().start()
        if self._disconnect_waiter is None:
            self._disconnect_waiter = futures.Future(loop=self._loop)

    @asyncio.coroutine
    def stop(self):
        yield from super().stop()
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    @asyncio.coroutine
    def wait_disconnect(self):
        return (yield from self._disconnect_waiter)

    def handle_write_timeout(self):
        pass

    def handle_read_timeout(self):
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    @asyncio.coroutine
    def handle_disconnect(self, disconnect):
        self.logger.debug("Client disconnecting")
        if self._disconnect_waiter and not self._disconnect_waiter.done():
            self.logger.debug("Setting waiter result to %r" % disconnect)
            self._disconnect_waiter.set_result(disconnect)

    @asyncio.coroutine
    def handle_connection_closed(self):
        yield from self.handle_disconnect(None)

    @asyncio.coroutine
    def handle_connect(self, connect: ConnectPacket):
        # The broker handler shouldn't receive CONNECT messages during message handling
        # as CONNECT messages are managed by the broker on client connection
        self.logger.error('%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling' %
                          (self.session.client_id, format_client_message(self.session)))
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        yield from self._send_packet(PingRespPacket.build())

    @asyncio.coroutine
    def handle_subscribe(self, subscribe: SubscribePacket):
        subscription = {'packet_id': subscribe.variable_header.packet_id, 'topics': subscribe.payload.topics}
        yield from self._pending_subscriptions.put(subscription)

    @asyncio.coroutine
    def handle_unsubscribe(self, unsubscribe: UnsubscribePacket):
        unsubscription = {'packet_id': unsubscribe.variable_header.packet_id, 'topics': unsubscribe.payload.topics}
        yield from self._pending_unsubscriptions.put(unsubscription)

    @asyncio.coroutine
    def get_next_pending_subscription(self):
        subscription = yield from self._pending_subscriptions.get()
        return subscription

    @asyncio.coroutine
    def get_next_pending_unsubscription(self):
        unsubscription = yield from self._pending_unsubscriptions.get()
        return unsubscription

    @asyncio.coroutine
    def mqtt_acknowledge_subscription(self, packet_id, return_codes):
        suback = SubackPacket.build(packet_id, return_codes)
        yield from self._send_packet(suback)

    @asyncio.coroutine
    def mqtt_acknowledge_unsubscription(self, packet_id):
        unsuback = UnsubackPacket.build(packet_id)
        yield from self._send_packet(unsuback)

    @asyncio.coroutine
    def mqtt_connack_authorize(self, authorize: bool):
        if authorize:
            connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED)
        else:
            connack = ConnackPacket.build(self.session.parent, NOT_AUTHORIZED)
        yield from self._send_packet(connack)

    @classmethod
    @asyncio.coroutine
    def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
        """

        :param reader:
        :param writer:
        :param plugins_manager:
        :param loop:
        :return:
        """
        remote_address, remote_port = writer.get_peer_info()
        connect = yield from ConnectPacket.from_stream(reader)
        yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
        if connect.payload.client_id is None:
            raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present' )

        if connect.variable_header.will_flag:
            if connect.payload.will_topic is None or connect.payload.will_message is None:
                raise MQTTException('will flag set, but will topic/message not present in payload')

        if connect.variable_header.reserved_flag:
            raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
        if connect.proto_name != "MQTT":
            raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.proto_name)

        connack = None
        error_msg = None
        if connect.proto_level != 4:
            # only MQTT 3.1.1 supported
            error_msg = 'Invalid protocol from %s: %d' % \
                              (format_client_message(address=remote_address, port=remote_port), connect.proto_level)
            connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION)  # [MQTT-3.2.2-4] session_parent=0
        elif not connect.username_flag and connect.password_flag:
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and not connect.password_flag:
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and connect.username is None:
            error_msg = 'Invalid username from %s' % \
                              (format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.password_flag and connect.password is None:
            error_msg = 'Invalid password %s' % (format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.clean_session_flag is False and (connect.payload.client_id is None or connect.payload.client_id == ""):
            error_msg = '[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % \
                              format_client_message(address=remote_address, port=remote_port)
            connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
        if connack is not None:
            yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
            yield from connack.to_stream(writer)
            yield from writer.close()
            raise MQTTException(error_msg)

        incoming_session = Session(loop)
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        if connect.keep_alive > 0:
            incoming_session.keep_alive = connect.keep_alive
        else:
            incoming_session.keep_alive = 0

        handler = cls(plugins_manager, loop=loop)
        return handler, incoming_session
Example #37
class BasePlugin(metaclass=ABCMeta):
    '''Core plug-in functionality

    A Sphinx plug-in needs to provide a minimum set of services in order to be
    useful.  Those are defined here, with default implementations where it
    makes sense.
    '''

    # This is a handle to the data bus.  It's set when we are registered.
    _databus = None

    # Type manager handle
    _tm = None

    def __init__(self, runner, plugins, source = None):
        '''Constructor

        This is how our plugin pipeline is constructed.  Each plugin instance
        is created when the input script is read, and they are chained together,
        from source to sink, here.

        This method _must_ be called with the event loop from which it will be
        called in the future, e.g., asyncio.get_event_loop().
        '''

        # A dict that maps each destination for our data, to the type that the
        # destination can consume.
        self._sinks = {}

        # Retain a pointer to our source, and add ourselves to its list of sinks.
        self._source = source
        if source:
            # Validate that we can process data from this source
            sink_types = set(source.sources()).intersection(self.sinks())
            if len(sink_types):
                source._set_sink(self, sink_types.pop())
                
            else:
                err = "{} cannot sink '{}'".format(self, source.sources())
                _log.error(err)
                raise ImpedenceMismatchError(err)

        # Our input queue
        self._queue = Queue()

        self.runner = runner
        self._plugins = plugins

        # create_task schedules the execution of the coroutine "run", wrapped
        # in a future.
        self._task = self.runner.create_task(self.run())


    def __getattr__(self, name):
        '''Plugin Pipeline Building

        This method is called when Python can't find a requested attribute. We
        use it to create a new plugin instance to add to the pipeline.
        '''
        if name in self._plugins:
            return partial(self._plugins[name], source = self)

        else:
            raise AttributeError


    def _set_sink(self, sink, data_type):
        '''Register a sink

        Called during initialization to register a sink (destination for our
        output).
        '''
        self._sinks[sink] = data_type
        

    @coroutine
    def publish(self, data):
        '''Publish data

        Called by a plugin to publish data to its sinks.
        '''
        for sink, data_type in self._sinks.items():
            # Special case 'None', since that's our 'eof'.  See the 'done'
            # method below.
            if data:
                data = self.xform_data(data, data_type)
            yield from self._databus.publish(data, sink)


    @coroutine
    def write_data(self, data):
        '''Write data to queue
        
        Called by the databus controller to enqueue data from our source.
        '''
        yield from self._queue.put(data)
        

    @coroutine
    def read_data(self):
        '''Read data from queue

        Called by plugins to get data from their sources.
        '''
        payload = yield from self._queue.get()
        return payload
        

    @coroutine
    def done(self):
        '''The plugin is finished

        Called by a plugin to indicate to its sinks that it has no more data.
        '''
        # TODO: It feels clumsy to use getting "None" as "EOT".  Also, it
        # requires that the plugins test for it to stop reading data.
        yield from self.publish(None)


        
    # Sources and sinks, oh my!  These follow the current flow analogy.
    # Data flows from a source to a sink.  Our input comes from a source,
    # and we sink it, process the data in some manner, and then source
    # it to the next plugin in the pipeline.
    @classmethod
    def sinks(cls):
        '''Sink types

        These are an array of types that we sink, i.e., read.
        '''
        return []


    @classmethod
    def sources(cls):
        '''Source types

        These are an array of types that we source, i.e., write.
        '''
        return []


    @classmethod
    def set_databus(cls, db):
        '''A handler to the Semantic Databus

        This gets set when the plug-in is registered.
        '''
        cls._databus = db
        cls._tm = db._typemgr


    @classmethod
    def script_name(cls):
        '''Return the plug-in's script name.

        The script name is how the plug-in is referred to by command scripts.
        '''
        pass


    @abstractmethod
    def xform_data(self, data, to_type):
        '''Transform data to a specific type

        This method must be able to transform the input, 'data', to the 'to_type'.
        The plugin will only be responsible for transforming types that are
        specified in our "sources" method.

        There is no expectation on how the plugin represents 'data', but it would
        make sense to do so in some manner that is not only natural for the plugin,
        but also easily transformed.
        '''
        pass
    

    @coroutine
    @abstractmethod
    def run(self):
        '''Our main method where work happens

        This is the method that will be invoked when the plug-in needs to do
        some work.
        '''
        pass
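A minimal concrete plugin, sketched as an assumption (the 'text' type name and the upper-casing behaviour are illustrative, not part of the original pipeline):

class UpperCase(BasePlugin):
    '''Hypothetical plugin that upper-cases text flowing through the pipeline.'''

    @classmethod
    def sinks(cls):
        return ['text']  # types we can read

    @classmethod
    def sources(cls):
        return ['text']  # types we write

    def xform_data(self, data, to_type):
        return data  # already in the expected shape

    @coroutine
    def run(self):
        while True:
            data = yield from self.read_data()
            if data is None:  # 'None' is the EOT marker used by done()
                yield from self.done()
                break
            yield from self.publish(data.upper())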
Example #38
class Cloner(object):
    def __init__(self, root):
        self.visited_urls = []
        self.root = self.add_scheme(root)
        if len(self.root.host) < 4:
            sys.exit('invalid target {}'.format(self.root.host))
        self.target_path = '/opt/snare/pages/{}'.format(self.root.host)

        if not os.path.exists(self.target_path):
            os.mkdir(self.target_path)

        self.new_urls = Queue()

    @staticmethod
    def add_scheme(url):
        if yarl.URL(url).scheme:
            new_url = yarl.URL(url)
        else:
            new_url = yarl.URL('http://' + url)
        return new_url

    @asyncio.coroutine
    def process_link(self, url, check_host=False):
        url = yarl.URL(url)
        if check_host:
            if (url.host != self.root.host or url.fragment
                            or url in self.visited_urls):
                return None
        if not url.is_absolute():
            url = self.root.join(url)

        yield from self.new_urls.put(url)
        return url.relative().human_repr()

    @asyncio.coroutine
    def replace_links(self, data):
        soup = BeautifulSoup(data, 'html.parser')

        # find all relative links
        for link in soup.findAll(href=True):
            res = yield from self.process_link(link['href'], check_host=True)
            if res is not None:
                link['href'] = res

        # find all images and scripts
        for elem in soup.findAll(src=True):
            res = yield from self.process_link(elem['src'])
            if res is not None:
                elem['src'] = res

        # find all action elements
        for act_link in soup.findAll(action=True):
            res = yield from self.process_link(act_link['action'])
            if res is not None:
                act_link['action'] = res

        # prevent redirects
        for redir in soup.findAll(True, attrs={'name': re.compile('redirect.*')}):
            redir['value'] = yarl.URL(redir['value']).relative().human_repr()

        return soup

    @asyncio.coroutine
    def get_body(self):
        while not self.new_urls.empty():
            current_url = yield from self.new_urls.get()
            if current_url in self.visited_urls:
                continue
            self.visited_urls.append(current_url)
            if current_url.name:
                file_name = current_url.name
            elif current_url.raw_path != '/':
                file_name = current_url.path.rsplit('/')[1]
            else:
                file_name = 'index.html'
            file_path = os.path.dirname(current_url.path)
            if file_path == '/':
                file_path = self.target_path
            else:
                file_path = os.path.join(self.target_path, file_path[1:])

            print('path: ', file_path, 'name: ', file_name)

            if file_path and not os.path.exists(file_path):
                os.makedirs(file_path)

            data = None
            try:
                with aiohttp.Timeout(10.0):
                    with aiohttp.ClientSession() as session:
                        response = yield from session.get(current_url)
                        data = yield from response.read()
            except aiohttp.ClientError as client_error:
                print(client_error)
            else:
                response.release()
                session.close()
            if data is not None:
                if re.match(re.compile(r'.*\.(html|php)'), file_name):
                    soup = yield from self.replace_links(data)
                    data = str(soup).encode()
                with open(os.path.join(file_path, file_name), 'wb') as index_fh:
                    index_fh.write(data)
                if '.css' in file_name:
                    css = cssutils.parseString(data)
                    for carved_url in cssutils.getUrls(css):
                        if carved_url.startswith('data'):
                            continue
                        carved_url = yarl.URL(carved_url)
                        if not carved_url.is_absolute():
                            carved_url = self.root.join(carved_url)
                        if carved_url not in self.visited_urls:
                            yield from self.new_urls.put(carved_url)

    @asyncio.coroutine
    def run(self):
        yield from self.new_urls.put(self.root)
        return (yield from self.get_body())
Example #39
class Crawler:
    def __init__(self, root_url, max_redirect):
        self.max_tasks = 10
        self.max_redirect = max_redirect
        self.q = Queue()
        self.seen_urls = set()

        # aiohttp's ClientSession does connection pooling and
        # HTTP keep-alives for us.
        self.session = aiohttp.ClientSession(loop=loop)

        # Put (URL, max_redirect) in the Queue
        self.q.put_nowait((root_url, self.max_redirect))
        
    @asyncio.coroutine
    def crawl(self):
        '''Run the crawler until all work is done.'''
        workers = [asyncio.Task(self.work())
                   for _ in range(self.max_tasks)]

        # When all work is done, exit.
        yield from self.q.join()
        for w in workers:
            w.cancel()

    @asyncio.coroutine
    def work(self):
        while True:
            url, max_redirect = yield from self.q.get()

            # Download page and add new links to self.q
            yield from self.fetch(url, max_redirect)
            self.q.task_done()

    @asyncio.coroutine
    def fetch(self, url, max_redirect):
        # Handle redirects ourselves.
        response = yield from self.session.get(
            url, allow_redirects=False)

        try:
            if is_redirect(response):
                if max_redirect > 0:
                    next_url = response.headers['location']
                    if next_url in self.seen_urls:
                        # We have done this before.
                        return

                    # Remember we have seen this url.
                    self.seen_urls.add(next_url)

                    # Follow the redirect. One less redirect remains.
                    self.q.put_nowait((next_url, max_redirect - 1))
            else:
                links = yield from self.parse_links(response)
                # Python set-logic:
                for link in links.difference(self.seen_urls):
                    self.q.put_nowait((link, self.max_redirect))
                self.seen_urls.update(links)
        finally:
            # Return connection to pool.
            yield from response.release()