예제 #1
0
class KthLargest:
    """
    Track the k-th largest value of a growing stream of numbers.

    Only the ``k`` largest values seen so far are retained in a min-ordered
    priority queue, so the smallest retained value is the k-th largest overall.
    """

    def __init__(self, k, nums):
        """
        :param k: rank to report (1 means the maximum); must be at least 1
        :param nums: initial values of the stream
        :raise ValueError: if ``k`` is smaller than 1
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k!r}")
        self.k = k
        self.pq = PriorityQueue()
        for i in nums:
            self._push(i)

    def _push(self, n):
        """Insert ``n``, then evict the smallest value if more than ``k`` are held."""
        # put_nowait/get_nowait never block on an unbounded queue
        self.pq.put_nowait(n)
        if self.pq.qsize() > self.k:
            self.pq.get_nowait()

    def add(self, n):
        """
        Add ``n`` to the stream and return the current k-th largest value.

        If fewer than ``k`` values have been seen so far, the smallest value
        seen is returned instead.
        """
        self._push(n)
        # The queue offers no peek operation: pop the smallest retained value
        # and put it straight back so the queue contents are left unchanged.
        kth_largest = self.pq.get_nowait()
        self.pq.put_nowait(kth_largest)
        return kth_largest
예제 #2
0
파일: peers.py 프로젝트: wangroot/trinity
class WaitingPeers(Generic[TChainPeer]):
    """
    A pool of peers that are idle and waiting to be given work.

    Peers come back out best-ranked first, where the rank reflects the peer's
    throughput for the configured response command(s).
    """
    _waiting_peers: 'PriorityQueue[SortableTask[TChainPeer]]'
    _response_command_type: Tuple[Type[CommandAPI[Any]], ...]

    def __init__(
        self,
        response_command_type: Union[Type[CommandAPI[Any]],
                                     Sequence[Type[CommandAPI[Any]]]],
        sort_key: Callable[[PerformanceAPI],
                           float] = _items_per_second) -> None:
        """
        :param response_command_type: a single command class, or a sequence of them,
            identifying which exchanges contribute to a peer's rank
        :param sort_key: maps an exchange's performance tracker to a score; the lowest
            score is ranked first
        :raise TypeError: if ``response_command_type`` is neither a class nor a sequence
        """
        self._waiting_peers = PriorityQueue()

        # Normalize the accepted input shapes to a tuple, which issubclass() understands
        command_types: Tuple[Type[CommandAPI[Any]], ...]
        if isinstance(response_command_type, type):
            command_types = (response_command_type, )
        elif isinstance(response_command_type, collections.abc.Sequence):
            command_types = tuple(response_command_type)
        else:
            raise TypeError(f"Unsupported value: {response_command_type}")
        self._response_command_type = command_types

        self._sort_key = sort_key
        self._peer_wrapper = SortableTask.orderable_by_func(
            self._get_peer_rank)

    def _get_peer_rank(self, peer: TChainPeer) -> float:
        """
        Score a peer by averaging the sort key over every matching exchange.

        :raise ValidationError: if none of the peer's exchanges responds with one of the
            configured command types
        """
        matching_scores = []
        for exchange in peer.chain_api.exchanges:
            if issubclass(exchange.get_response_cmd_type(),
                          self._response_command_type):
                matching_scores.append(self._sort_key(exchange.tracker))

        if not matching_scores:
            raise ValidationError(
                f"Could not find any exchanges on {peer} "
                f"with response {self._response_command_type!r}")

        # Usually a single exchange matches; when several do, use their mean score
        return sum(matching_scores) / len(matching_scores)

    def put_nowait(self, peer: TChainPeer) -> None:
        """Queue a peer; a peer whose wrapping raises PeerConnectionLost is silently dropped."""
        try:
            ranked_peer = self._peer_wrapper(peer)
        except PeerConnectionLost:
            return
        else:
            self._waiting_peers.put_nowait(ranked_peer)

    async def get_fastest(self) -> TChainPeer:
        """Wait for and return the best-ranked peer that is still alive."""
        while True:
            candidate = (await self._waiting_peers.get()).original
            # a peer may have gone offline while it sat in the queue; skip it if so
            if candidate.is_alive:
                return candidate

    def pop_nowait(self) -> TChainPeer:
        """
        Return the best-ranked peer that is still alive, without waiting.

        :raise QueueEmpty: if no peer is available
        """
        while True:
            candidate = self._waiting_peers.get_nowait().original
            # a peer may have gone offline while it sat in the queue; skip it if so
            if candidate.is_alive:
                return candidate

    def __len__(self) -> int:
        """Number of peers currently queued (including any that have since gone offline)."""
        return self._waiting_peers.qsize()
예제 #3
0
class SubNameBrute(object):
    """
    Asynchronous sub-domain brute forcer for one worker process, resolving with aiodns.

    Candidate names live in a priority queue of ``(brace_count, name)`` tuples: plain
    names carry priority 0 and wildcard templates (names containing ``{alphnum}``,
    ``{alpha}``, ``{num}`` or ``{next_sub}``) carry the number of placeholders left to
    expand. Results are appended to a per-process output file.
    """
    def __init__(self, *params):
        """
        :param params: exactly nine positional values, in this order: domain, options,
            process_num, dns_servers, next_subs, scan_count, found_count,
            queue_size_array, tmp_dir. ``scan_count`` / ``found_count`` are updated
            through ``.value`` and ``queue_size_array`` is indexed by ``process_num``
            (presumably multiprocessing shared objects -- confirm with the caller).
        """
        self.domain, self.options, self.process_num, self.dns_servers, self.next_subs, \
            self.scan_count, self.found_count, self.queue_size_array, tmp_dir = params
        self.dns_count = len(self.dns_servers)
        self.scan_count_local = 0
        self.found_count_local = 0
        # one resolver per scanning coroutine
        self.resolvers = [
            aiodns.DNSResolver(tries=1) for _ in range(self.options.threads)
        ]
        self.queue = PriorityQueue()
        self.ip_dict = {}
        self.found_subs = set()
        self.timeout_subs = {}
        self.count_time = time.time()
        self.outfile = open(
            '%s/%s_part_%s.txt' % (tmp_dir, self.domain, self.process_num),
            'w')
        self.normal_names_set = set()
        self.lock = asyncio.Lock()
        self.loop = None
        # '1' while a scanning coroutine still has work, '0' once it found the queue empty
        self.threads_status = ['1'] * self.options.threads

    async def load_sub_names(self):
        """
        Read the word list at ``options.file`` and enqueue this process's share of it.

        Blank and duplicate lines are ignored. Lines containing ``{`` are wildcard
        templates; plain names already covered by a template are dropped so they are
        not scanned twice. Every ``options.process``-th entry, starting at
        ``process_num``, is queued.
        """
        normal_lines = []
        wildcard_lines = []
        wildcard_set = set()
        regex_list = []
        lines = set()
        with open(self.options.file) as inFile:
            for line in inFile:
                sub = line.strip()
                if not sub or sub in lines:
                    continue
                lines.add(sub)

                brace_count = sub.count('{')
                if brace_count > 0:
                    wildcard_lines.append((brace_count, sub))
                    sub = sub.replace('{alphnum}', '[a-z0-9]')
                    sub = sub.replace('{alpha}', '[a-z]')
                    sub = sub.replace('{num}', '[0-9]')
                    if sub not in wildcard_set:
                        wildcard_set.add(sub)
                        regex_list.append('^' + sub + '$')
                else:
                    normal_lines.append(sub)
                    self.normal_names_set.add(sub)

        if regex_list:
            pattern = '|'.join(regex_list)
            _regex = re.compile(pattern)
            # Build a new list: removing from a list while iterating over it
            # skips the element that follows each removal.
            normal_lines = [
                line for line in normal_lines if not _regex.search(line)
            ]

        for _ in normal_lines[self.process_num::self.options.process]:
            await self.queue.put((0, _))  # priority set to 0
        for _ in wildcard_lines[self.process_num::self.options.process]:
            await self.queue.put(_)

    async def scan(self, j):
        """
        Scanning coroutine number ``j``: pull names from the queue and resolve them.

        Returns once the queue is empty and every scanning coroutine is idle, or when
        an unexpected error occurs while fetching/expanding a queue entry.

        :param j: index of this coroutine; selects its resolver and primary DNS server
        """
        # Primary server is chosen round-robin; a random different server (if any
        # exists) is added as secondary. The complete list is assigned in one step
        # rather than appended to whatever the ``nameservers`` attribute returns.
        primary_server = self.dns_servers[j % self.dns_count]
        nameservers = [primary_server]
        other_servers = [s for s in self.dns_servers if s != primary_server]
        if other_servers:
            nameservers.append(random.choice(other_servers))
        self.resolvers[j].nameservers = nameservers

        while True:
            try:
                # roughly once per second, flush local counters to the shared ones
                if time.time() - self.count_time > 1.0:
                    async with self.lock:
                        self.scan_count.value += self.scan_count_local
                        self.scan_count_local = 0
                        self.queue_size_array[
                            self.process_num] = self.queue.qsize()
                        if self.found_count_local:
                            self.found_count.value += self.found_count_local
                            self.found_count_local = 0
                        self.count_time = time.time()

                try:
                    brace_count, sub = self.queue.get_nowait()
                    self.threads_status[j] = '1'
                except asyncio.queues.QueueEmpty:
                    self.threads_status[j] = '0'
                    await asyncio.sleep(0.5)
                    # stop only when every coroutine has reported an empty queue
                    if '1' not in self.threads_status:
                        break
                    else:
                        continue

                if brace_count > 0:
                    # wildcard template: expand one placeholder and re-queue the results
                    brace_count -= 1
                    if sub.find('{next_sub}') >= 0:
                        for _ in self.next_subs:
                            await self.queue.put(
                                (0, sub.replace('{next_sub}', _)))
                    if sub.find('{alphnum}') >= 0:
                        for _ in 'abcdefghijklmnopqrstuvwxyz0123456789':
                            await self.queue.put(
                                (brace_count, sub.replace('{alphnum}', _, 1)))
                    elif sub.find('{alpha}') >= 0:
                        for _ in 'abcdefghijklmnopqrstuvwxyz':
                            await self.queue.put(
                                (brace_count, sub.replace('{alpha}', _, 1)))
                    elif sub.find('{num}') >= 0:
                        for _ in '0123456789':
                            await self.queue.put(
                                (brace_count, sub.replace('{num}', _, 1)))
                    continue
            except Exception:
                import traceback
                print(traceback.format_exc())
                break

            try:

                if sub in self.found_subs:
                    continue

                self.scan_count_local += 1
                cur_domain = sub + '.' + self.domain
                # print('Query %s' % cur_domain)
                answers = await self.resolvers[j].query(cur_domain, 'A')

                if answers:
                    self.found_subs.add(sub)
                    ips = ', '.join(sorted([answer.host
                                            for answer in answers]))
                    # placeholder addresses are not reported
                    if ips in ['1.1.1.1', '127.0.0.1', '0.0.0.0', '0.0.0.1']:
                        continue
                    if self.options.i and is_intranet(answers[0].host):
                        continue

                    # best effort: follow a CNAME that points to another name in this domain
                    try:
                        self.scan_count_local += 1
                        answers = await self.resolvers[j].query(
                            cur_domain, 'CNAME')
                        cname = answers[0].target.to_unicode().rstrip('.')
                        if cname.endswith(
                                self.domain) and cname not in self.found_subs:
                            cname_sub = cname[:len(cname) - len(self.domain) -
                                              1]  # new sub
                            if cname_sub not in self.normal_names_set:
                                self.found_subs.add(cname)
                                await self.queue.put((0, cname_sub))
                    except Exception:
                        pass

                    # stop reporting once more than 30 names sharing the same last
                    # label resolve to the same address set
                    first_level_sub = sub.split('.')[-1]
                    if (first_level_sub, ips) not in self.ip_dict:
                        self.ip_dict[(first_level_sub, ips)] = 1
                    else:
                        self.ip_dict[(first_level_sub, ips)] += 1
                        if self.ip_dict[(first_level_sub, ips)] > 30:
                            continue

                    self.found_count_local += 1

                    self.outfile.write(
                        cur_domain.ljust(30) + '\t' + ips + '\n')
                    self.outfile.flush()
                    # Probe a name that should not exist below the found one; only when
                    # the probe fails with error code 4 are next-level names queued
                    # (presumably code 4 means "domain not found", i.e. no wildcard
                    # record -- confirm against the resolver's error codes).
                    try:
                        self.scan_count_local += 1
                        await self.resolvers[j].query(
                            'lijiejie-test-not-existed.' + cur_domain, 'A')
                    except aiodns.error.DNSError as e:
                        if e.args[0] in [4]:
                            if self.queue.qsize() < 50000:
                                for _ in self.next_subs:
                                    await self.queue.put((0, _ + '.' + sub))
                            else:
                                # queue is large: defer the expansion via a template
                                await self.queue.put((1, '{next_sub}.' + sub))
                    except Exception:
                        pass

            except aiodns.error.DNSError as e:
                if e.args[0] in [1, 4]:
                    pass
                elif e.args[0] in [
                        11, 12
                ]:  # 12 timeout   # (11, 'Could not contact DNS servers')
                    # print('timed out sub %s' % sub)
                    self.timeout_subs[sub] = self.timeout_subs.get(sub, 0) + 1
                    if self.timeout_subs[sub] <= 1:
                        await self.queue.put((0, sub))  # Retry
                else:
                    print(e)
            except asyncio.TimeoutError:
                pass
            except Exception as e:
                import traceback
                traceback.print_exc()
                with open('errors.log', 'a') as errFile:
                    errFile.write('[%s] %s\n' % (type(e), str(e)))

    async def async_run(self):
        """Load the word list, then run ``options.threads`` scanning coroutines to completion."""
        await self.load_sub_names()
        tasks = [self.scan(i) for i in range(self.options.threads)]
        await asyncio.gather(*tasks)

    def run(self):
        """Blocking entry point: drive :meth:`async_run` on the current event loop."""
        self.loop = asyncio.get_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.async_run())
예제 #4
0
class TaskQueue(Generic[TTask]):
    """
    TaskQueue keeps priority-order track of pending tasks, with a limit on number pending.

    A producer of tasks will insert pending tasks with await add(), which will not return until
    all tasks have been added to the queue.

    A task consumer calls await get() to retrieve tasks for processing. Tasks will be returned in
    priority order. If no tasks are pending, get()
    will pause until at least one is available. Only one consumer will have a task "checked out"
    from get() at a time.

    After tasks are successfully completed, the consumer will call complete() to remove them from
    the queue. The consumer doesn't need to complete all tasks, but any uncompleted tasks will be
    considered abandoned. Another consumer can pick it up at the next get() call.
    """

    # a function that determines the priority order (lower int is higher priority)
    _order_fn: FunctionProperty[Callable[[TTask], Any]]

    # batches of tasks that have been started but not completed
    _in_progress: Dict[int, Tuple[TTask, ...]]

    # all tasks that have been placed in the queue and have not been started
    _open_queue: 'PriorityQueue[Tuple[Any, TTask]]'

    # all tasks that have been placed in the queue and have not been completed
    _tasks: Set[TTask]

    def __init__(self,
                 maxsize: int = 0,
                 order_fn: Callable[[TTask], Any] = identity,
                 *,
                 loop: AbstractEventLoop = None) -> None:
        """
        :param maxsize: maximum number of not-yet-completed tasks; zero or less means no limit
        :param order_fn: maps a task to its sort rank (lower rank is returned first)
        :param loop: event loop handed to the underlying asyncio primitives, only if given
        """
        self._maxsize = maxsize

        # The asyncio primitives stopped accepting a ``loop`` argument in Python 3.10
        # (even ``loop=None`` is rejected), so only forward it when explicitly supplied.
        loop_kwargs: Dict[str, Any] = {} if loop is None else {'loop': loop}
        self._full_lock = Lock(**loop_kwargs)
        self._open_queue = PriorityQueue(maxsize, **loop_kwargs)

        self._order_fn = order_fn
        self._id_generator = count()
        self._tasks = set()
        self._in_progress = {}

    async def add(self, tasks: Tuple[TTask, ...]) -> None:
        """
        add() will insert as many tasks as can be inserted until the queue fills up.
        Then it will pause until the queue is no longer full, and continue adding tasks.
        It will finally return when all tasks have been inserted.

        :raise ValidationError: if ``tasks`` is not a tuple, or contains a task that is
            already present in the queue
        :raise ~asyncio.QueueFull: if the internal queue unexpectedly rejects a task
        """
        if not isinstance(tasks, tuple):
            raise ValidationError(
                f"must pass a tuple of tasks to add(), but got {tasks!r}")

        already_pending = self._tasks.intersection(tasks)
        if already_pending:
            raise ValidationError(
                f"Duplicate tasks detected: {already_pending!r} are already present in the queue"
            )

        # make sure to insert the highest-priority items first, in case queue fills up
        remaining = tuple(
            sorted((self._order_fn(task), task) for task in tasks))

        while remaining:
            num_tasks = len(self._tasks)

            if self._maxsize <= 0:
                # no cap at all, immediately insert all tasks
                open_slots = len(remaining)
            elif num_tasks < self._maxsize:
                # there is room to add at least one more task
                open_slots = self._maxsize - num_tasks
            else:
                # wait until there is room in the queue
                await self._full_lock.acquire()

                # the current number of tasks has changed, can't reuse num_tasks
                num_tasks = len(self._tasks)
                open_slots = self._maxsize - num_tasks

            queueing, remaining = remaining[:open_slots], remaining[
                open_slots:]

            for task in queueing:
                # There will always be room in _open_queue until _maxsize is reached
                try:
                    self._open_queue.put_nowait(task)
                except QueueFull as exc:
                    # should be unreachable; re-raise with enough state to debug it
                    task_idx = queueing.index(task)
                    qsize = self._open_queue.qsize()
                    raise QueueFull(
                        f'TaskQueue unsuccessful in adding task {task[1]!r} because qsize={qsize}, '
                        f'num_tasks={num_tasks}, maxsize={self._maxsize}, open_slots={open_slots}, '
                        f'num queueing={len(queueing)}, len(_tasks)={len(self._tasks)}, task_idx='
                        f'{task_idx}, queuing={queueing}, original msg: {exc}',
                    ) from exc

            unranked_queued = tuple(task for _rank, task in queueing)
            self._tasks.update(unranked_queued)

            # free the lock again if room opened up, so a waiting add() can proceed
            if self._full_lock.locked() and len(self._tasks) < self._maxsize:
                self._full_lock.release()

    def get_nowait(self,
                   max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
        """
        Get pending tasks. If no tasks are pending, raise an exception.

        Note that the exception raised for an empty queue is QueueFull, not QueueEmpty.

        :param max_results: return up to this many pending tasks. If None, return all pending tasks.
        :return: (batch_id, tasks to attempt)
        :raise ~asyncio.QueueFull: if no tasks are available
        """
        if self._open_queue.empty():
            raise QueueFull("No tasks are available to get")
        else:
            pending_tasks = self._get_nowait(max_results)

            # Generate a pending batch of tasks, so uncompleted tasks can be inferred
            next_id = next(self._id_generator)
            self._in_progress[next_id] = pending_tasks

            return (next_id, pending_tasks)

    async def get(self,
                  max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
        """
        Get pending tasks. If no tasks are pending, wait until a task is added.

        :param max_results: return up to this many pending tasks. If None, return all pending tasks.
        :return: (batch_id, tasks to attempt)
        :raise ValidationError: if ``max_results`` is given and smaller than 1
        """
        if max_results is not None and max_results < 1:
            raise ValidationError(
                f"Must request at least one task to process, not {max_results!r}"
            )

        # if the queue is empty, wait until at least one item is available
        queue = self._open_queue
        if queue.empty():
            _rank, first_task = await queue.get()
        else:
            _rank, first_task = queue.get_nowait()

        # In order to return from get() as soon as possible, never await again.
        # Instead, take only the tasks that are already available.
        if max_results is None:
            remaining_count = None
        else:
            remaining_count = max_results - 1
        remaining_tasks = self._get_nowait(remaining_count)

        # Combine the first and remaining tasks
        all_tasks = (first_task, ) + remaining_tasks

        # Generate a pending batch of tasks, so uncompleted tasks can be inferred
        next_id = next(self._id_generator)
        self._in_progress[next_id] = all_tasks

        return (next_id, all_tasks)

    def _get_nowait(self, max_results: int = None) -> Tuple[TTask, ...]:
        """
        Pop up to ``max_results`` tasks that are immediately available (all of them if None).

        :return: the tasks in priority order, without their internal rank
        """
        queue = self._open_queue

        # How many results do we want?
        available = queue.qsize()
        if max_results is None:
            num_tasks = available
        else:
            num_tasks = min((available, max_results))

        # Pull that many (rank, task) pairs off the queue without waiting
        ranked_tasks = tuple(queue.get_nowait() for _ in range(num_tasks))

        # strip out the rank value used internally for sorting in the priority queue
        return tuple(task for _rank, task in ranked_tasks)

    def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
        """
        Close a batch: drop the completed tasks and re-queue the ones left unfinished.

        :param batch_id: the id returned alongside the batch by get() / get_nowait()
        :param completed: the tasks of that batch which were finished
        :raise ValidationError: if the batch id is unknown, or ``completed`` contains a
            task that was not part of the batch (the batch then stays in progress)
        """
        if batch_id not in self._in_progress:
            raise ValidationError(
                f"batch id {batch_id} not recognized, with tasks {completed!r}"
            )

        attempted = self._in_progress.pop(batch_id)

        unrecognized_tasks = set(completed).difference(attempted)
        if unrecognized_tasks:
            # restore the batch so the caller can retry with a valid task list
            self._in_progress[batch_id] = attempted
            raise ValidationError(
                f"cannot complete tasks {unrecognized_tasks!r} in this batch, only {attempted!r}"
            )

        incomplete = set(attempted).difference(completed)

        for task in incomplete:
            # These tasks are already counted in the total task count, so there will be room
            self._open_queue.put_nowait((self._order_fn(task), task))

        self._tasks.difference_update(completed)

        # room may have opened up: let a producer blocked in add() continue
        if self._full_lock.locked() and len(self._tasks) < self._maxsize:
            self._full_lock.release()

    def __contains__(self, task: TTask) -> bool:
        """Determine if a task has been added and not yet completed"""
        return task in self._tasks
예제 #5
0
class SubNameBrute(object):
    """
    Asynchronous sub-domain brute forcer for one worker process, resolving with dnspython.

    Candidate names live in a priority queue of ``(brace_count, name)`` tuples: plain
    names carry priority 0 and wildcard templates (names containing ``{alphnum}``,
    ``{alpha}``, ``{num}`` or ``{next_sub}``) carry the number of placeholders left to
    expand. Results are appended to a per-process output file.
    """
    def __init__(self, *params):
        """
        :param params: exactly nine positional values, in this order: domain, options,
            process_num, dns_servers, next_subs, scan_count, found_count,
            queue_size_array, tmp_dir. ``scan_count`` / ``found_count`` are updated
            through ``.value`` and ``queue_size_array`` is indexed by ``process_num``
            (presumably multiprocessing shared objects -- confirm with the caller).
        """
        self.domain, self.options, self.process_num, self.dns_servers, self.next_subs, \
            self.scan_count, self.found_count, self.queue_size_array, tmp_dir = params
        self.dns_count = len(self.dns_servers)
        self.scan_count_local = 0
        self.found_count_local = 0
        # one resolver per scanning coroutine
        self.resolvers = [
            dns.asyncresolver.Resolver(configure=False)
            for _ in range(self.options.threads)
        ]
        for r in self.resolvers:
            r.lifetime = 6.0
            r.timeout = 10.0
        self.queue = PriorityQueue()
        self.ip_dict = {}
        self.found_subs = set()
        self.cert_subs = set()
        self.timeout_subs = {}
        self.no_server_subs = {}
        self.count_time = time.time()
        self.outfile = open(
            '%s/%s_part_%s.txt' % (tmp_dir, self.domain, self.process_num),
            'w')
        self.normal_names_set = set()
        self.lock = asyncio.Lock()
        # '1' while a scanning coroutine still has work, '0' once it has been idle for a while
        self.threads_status = ['1'] * self.options.threads

    async def load_sub_names(self):
        """
        Read the word list at ``options.file`` and enqueue this process's share of it.

        Blank and duplicate lines are ignored. Lines containing ``{`` are wildcard
        templates; plain names already covered by a template are dropped so they are
        not scanned twice. Every ``options.process``-th entry, starting at
        ``process_num``, is queued.
        """
        normal_lines = []
        wildcard_lines = []
        wildcard_set = set()
        regex_list = []
        lines = set()
        with open(self.options.file) as inFile:
            for line in inFile:
                sub = line.strip()
                if not sub or sub in lines:
                    continue
                lines.add(sub)

                brace_count = sub.count('{')
                if brace_count > 0:
                    wildcard_lines.append((brace_count, sub))
                    sub = sub.replace('{alphnum}', '[a-z0-9]')
                    sub = sub.replace('{alpha}', '[a-z]')
                    sub = sub.replace('{num}', '[0-9]')
                    if sub not in wildcard_set:
                        wildcard_set.add(sub)
                        regex_list.append('^' + sub + '$')
                else:
                    normal_lines.append(sub)
                    self.normal_names_set.add(sub)

        if regex_list:
            pattern = '|'.join(regex_list)
            _regex = re.compile(pattern)
            # Build a new list: removing from a list while iterating over it
            # skips the element that follows each removal.
            normal_lines = [
                line for line in normal_lines if not _regex.search(line)
            ]

        for _ in normal_lines[self.process_num::self.options.process]:
            await self.queue.put((0, _))  # priority set to 0
        for _ in wildcard_lines[self.process_num::self.options.process]:
            await self.queue.put(_)

    async def update_counter(self):
        """
        Every 0.5 seconds, flush the local counters and the queue size to the shared
        objects; returns once every scanning coroutine is marked idle.
        """
        while True:
            if '1' not in self.threads_status:
                return
            self.scan_count.value += self.scan_count_local
            self.scan_count_local = 0
            self.queue_size_array[self.process_num] = self.queue.qsize()
            if self.found_count_local:
                self.found_count.value += self.found_count_local
                self.found_count_local = 0
            self.count_time = time.time()
            await asyncio.sleep(0.5)

    async def check_https_alt_names(self, domain):
        """
        Best-effort harvest of new names from the TLS certificate served on port 443.

        Every DNS entry of the certificate's ``subjectAltName`` that ends with the target
        domain and is not already known is queued for scanning. Any failure (connection,
        handshake, missing certificate data) is ignored.

        :param domain: host name to connect to
        """
        writer = None
        try:
            _reader, writer = await asyncio.open_connection(
                host=domain,
                port=443,
                ssl=True,
                server_hostname=domain,
            )
            for item in writer.get_extra_info('peercert')['subjectAltName']:
                if item[0].upper() == 'DNS':
                    name = item[1].lower()
                    if name.endswith(self.domain):
                        sub = name[:len(name) - len(self.domain) -
                                   1]  # new sub
                        sub = sub.replace('*', '')
                        sub = sub.strip('.')
                        if sub and sub not in self.found_subs and \
                                sub not in self.normal_names_set and sub not in self.cert_subs:
                            self.cert_subs.add(sub)
                            await self.queue.put((0, sub))
        except Exception:
            pass
        finally:
            # always release the connection, including on timeout/cancellation
            if writer is not None:
                writer.close()

    async def do_query(self, j, cur_domain):
        """
        Resolve the A records of ``cur_domain`` with resolver ``j``, bounded to 10.2 seconds.

        :return: the resolver's answer object
        """
        # asyncio.wait_for did not work properly (it hung in some cases),
        # so async_timeout is used instead
        async with timeout(10.2):
            return await self.resolvers[j].resolve(cur_domain, 'A')

    async def scan(self, j):
        """
        Scanning coroutine number ``j``: pull names from the queue and resolve them.

        Returns once every scanning coroutine has been idle on an empty queue.

        :param j: index of this coroutine; selects its resolver and primary DNS server
        """
        # Primary server is chosen round-robin; a random different server is added as
        # secondary when one exists (no retry loop, so identical servers cannot hang it).
        primary_server = self.dns_servers[j % self.dns_count]
        nameservers = [primary_server]
        other_servers = [s for s in self.dns_servers if s != primary_server]
        if other_servers:
            nameservers.append(random.choice(other_servers))
        self.resolvers[j].nameservers = nameservers

        empty_counter = 0
        while True:
            try:
                brace_count, sub = self.queue.get_nowait()
                self.threads_status[j] = '1'
                empty_counter = 0
            except asyncio.queues.QueueEmpty:
                empty_counter += 1
                # only report idle after more than 10 consecutive empty polls
                if empty_counter > 10:
                    self.threads_status[j] = '0'
                if '1' not in self.threads_status:
                    break
                else:
                    await asyncio.sleep(0.1)
                    continue

            if brace_count > 0:
                # wildcard template: expand one placeholder and re-queue the results
                brace_count -= 1
                if sub.find('{next_sub}') >= 0:
                    for _ in self.next_subs:
                        await self.queue.put((0, sub.replace('{next_sub}', _)))
                if sub.find('{alphnum}') >= 0:
                    for _ in 'abcdefghijklmnopqrstuvwxyz0123456789':
                        await self.queue.put(
                            (brace_count, sub.replace('{alphnum}', _, 1)))
                elif sub.find('{alpha}') >= 0:
                    for _ in 'abcdefghijklmnopqrstuvwxyz':
                        await self.queue.put(
                            (brace_count, sub.replace('{alpha}', _, 1)))
                elif sub.find('{num}') >= 0:
                    for _ in '0123456789':
                        await self.queue.put(
                            (brace_count, sub.replace('{num}', _, 1)))
                continue

            try:
                if sub in self.found_subs:
                    continue

                self.scan_count_local += 1
                cur_domain = sub + '.' + self.domain

                answers = await self.do_query(j, cur_domain)
                if answers:
                    self.found_subs.add(sub)
                    ips = ', '.join(
                        sorted([answer.address for answer in answers]))
                    # a single placeholder address disqualifies the whole answer
                    invalid_ip_found = False
                    for answer in answers:
                        if answer.address in [
                                '1.1.1.1', '127.0.0.1', '0.0.0.0', '0.0.0.1'
                        ]:
                            invalid_ip_found = True
                    if invalid_ip_found:
                        continue
                    # records expose ``address`` (as used above), not ``host``
                    if self.options.i and is_intranet(answers[0].address):
                        continue

                    # best effort: follow a canonical name that stays inside this domain
                    try:
                        cname = str(answers.canonical_name)[:-1]
                        if cname != cur_domain and cname.endswith(self.domain):
                            cname_sub = cname[:len(cname) - len(self.domain) -
                                              1]  # new sub
                            if cname_sub not in self.found_subs and cname_sub not in self.normal_names_set:
                                await self.queue.put((0, cname_sub))
                    except Exception:
                        pass

                    # stop reporting once too many names map to the same address set
                    first_level_sub = sub.split('.')[-1]
                    max_found = 20

                    if self.options.w:
                        first_level_sub = ''
                        max_found = 3

                    if (first_level_sub, ips) not in self.ip_dict:
                        self.ip_dict[(first_level_sub, ips)] = 1
                    else:
                        self.ip_dict[(first_level_sub, ips)] += 1
                        if self.ip_dict[(first_level_sub, ips)] > max_found:
                            continue

                    self.found_count_local += 1

                    self.outfile.write(
                        cur_domain.ljust(30) + '\t' + ips + '\n')
                    self.outfile.flush()

                    if not self.options.no_cert_check:
                        async with timeout(10.0):
                            await self.check_https_alt_names(cur_domain)

                    # Probe a name that should not exist below the found one; only an
                    # NXDOMAIN reply leads to next-level names being queued.
                    try:
                        self.scan_count_local += 1
                        await self.do_query(
                            j, 'lijiejie-test-not-existed.' + cur_domain)

                    except dns.resolver.NXDOMAIN:
                        if self.queue.qsize() < 20000:
                            for _ in self.next_subs:
                                await self.queue.put((0, _ + '.' + sub))
                        else:
                            # queue is large: defer the expansion via a template
                            await self.queue.put((1, '{next_sub}.' + sub))
                    except Exception:
                        continue

            except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
                pass
            except dns.resolver.NoNameservers:
                self.no_server_subs[sub] = self.no_server_subs.get(sub, 0) + 1
                if self.no_server_subs[sub] <= 3:
                    await self.queue.put((0, sub))  # Retry again
            except (dns.exception.Timeout, dns.resolver.LifetimeTimeout):
                self.timeout_subs[sub] = self.timeout_subs.get(sub, 0) + 1
                if self.timeout_subs[sub] <= 3:
                    await self.queue.put((0, sub))  # Retry again
            except asyncio.TimeoutError:
                # timeouts from the ``timeout(...)`` blocks are expected and not logged;
                # matched by type rather than by the class's printed name
                pass
            except Exception as e:
                with open('errors.log', 'a') as errFile:
                    errFile.write('[%s] %s\n' % (type(e), str(e)))

    async def async_run(self):
        """Load the word list, then run the counter updater and all scanning coroutines."""
        await self.load_sub_names()
        tasks = [self.scan(i) for i in range(self.options.threads)]
        tasks.insert(0, self.update_counter())
        await asyncio.gather(*tasks)

    def run(self):
        """Blocking entry point: drive :meth:`async_run` on the current event loop."""
        loop = asyncio.get_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.async_run())
예제 #6
0
class JsonServer:
    """
    Module can receive a JSON string and if the action is "apply" it adds the
    contents of the "template" into a priority queue ordered by the "when"
    timestamp. The module also dumps the queue to the terminal every second.

    Example JSON string:
    {"action": "apply", "when": "2016-4-19 09:00:02", "template": "AAAAAA"}
    """
    def __init__(self):
        """
        Initialize and run an asyncio event loop for ever.

        The server listens on port 25000 on all interfaces. This call does not return.
        """
        self.loop = asyncio.get_event_loop()
        # No ``loop=`` argument: asyncio queues stopped accepting it in Python 3.10.
        self.queue = PriorityQueue()
        self.loop.create_task(self.json_server(('', 25000)))
        self.loop.create_task(self.queue_dumper())
        self.loop.run_forever()

    async def json_server(self, address):
        """
        Creates server connection at the given address and spawns one handler
        task per accepted client.
        :param address: Tuple (host, port) for eg. ('', 25000)
        """
        sock = socket(AF_INET, SOCK_STREAM)
        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        sock.bind(address)
        sock.listen(5)
        sock.setblocking(False)
        while True:
            client, addr = await self.loop.sock_accept(sock)
            self.loop.create_task(self.json_handler(client))

    async def json_handler(self, client):
        """
        Serves one client connection until it closes or sends undecodable data.

        Each received chunk is answered with ``Accepted: <chunk>`` when it was queued,
        otherwise with ``Rejected: <chunk>``.
        :param client: socket client
        """
        with client:
            while True:
                raw_data = await self.loop.sock_recv(client, 10000)
                if not raw_data:
                    break
                try:
                    data = json.loads(raw_data.strip().decode("utf-8"))
                except (ValueError, RecursionError):
                    # In case a valid json does not come in, reject the message
                    # and drop the connection.
                    # TODO: Log it.
                    await self.loop.sock_sendall(client,
                                                 b'Rejected: ' + raw_data)
                    break

                ts = None
                if self.is_valid_data_input(
                        data) and data['action'] == 'apply':
                    try:
                        ts = datetime.strptime(data['when'], '%Y-%m-%d %H:%M:%S')
                    except (TypeError, ValueError):
                        # "when" is not a string or not in the expected format:
                        # reject the message instead of crashing the handler
                        ts = None

                if ts is not None:
                    # NOTE(review): entries with identical timestamps are ordered by
                    # comparing templates, which fails for unorderable values such
                    # as JSON objects -- confirm templates are always strings.
                    await self.queue.put((ts, data['template']))
                    await self.loop.sock_sendall(client,
                                                 b'Accepted: ' + raw_data)
                else:
                    await self.loop.sock_sendall(client,
                                                 b'Rejected: ' + raw_data)

    def is_valid_data_input(self, data):
        """
        Validates incoming data: it must be a JSON object with non-empty
        'action', 'when' and 'template' entries.
        :param data: decoded JSON value (expected to be a dictionary)
        :return:  'True' if valid, else 'False'.
        """
        if not isinstance(data, dict):
            return False
        return all(data.get(key) for key in ('action', 'when', 'template'))

    async def queue_dumper(self):
        """
        Dumps status of queue to terminal (Eventually a file) every second.
        """
        # TODO : Ensure this also prints to a file.
        while True:
            if self.queue.qsize():
                # Drain and refill without awaiting in between, so no handler can
                # modify the queue while it is being listed.
                snapshot = []
                while not self.queue.empty():
                    snapshot.append(self.queue.get_nowait())
                print(chr(27) + "[2J")  # Bit of Ctr + L magic trick
                for element in snapshot:
                    print(element)
                    self.queue.put_nowait(element)
            await asyncio.sleep(1)