Example #1
class InMemoryCacheHandler:
    def __init__(self):
        dns_cache_size = int(setup_config_data_holder.dns_cache_size)
        dns_cache_ttl = int(setup_config_data_holder.dns_cache_ttl)
        self.cache = TTLCache(maxsize=dns_cache_size, ttl=dns_cache_ttl)

    def add_to_cache(self, key, value):
        self.cache[key] = value
        return self.cache

    def get_from_cache(self, key):
        return self.cache.get(key)

    def check_if_present(self, required_key):
        return required_key in self.cache

    def show_cache(self):
        logger.debug(f"[Process: Show Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")

    def clear_cache(self):
        logger.debug(f"[Process: Clearing Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")
        self.cache.clear()
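
A minimal, self-contained sketch of the cachetools calls the handler above wraps; the hostname, address, cache size, and TTL below are illustrative only.

import time

from cachetools import TTLCache

cache = TTLCache(maxsize=128, ttl=2)     # entries expire after ~2 seconds
cache["example.com"] = "93.184.216.34"   # add_to_cache
print(cache.get("example.com"))          # get_from_cache -> '93.184.216.34'
print("example.com" in cache)            # check_if_present -> True
time.sleep(3)
print(cache.get("example.com"))          # expired -> None
cache.clear()                            # clear_cache
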
Example #2
class WebSocket:

    def __init__(self, host, port, maxWebsockets, socketsTTL, processor):
        self.sockets = TTLCache(maxsize=maxWebsockets, ttl=socketsTTL)
        self.t = None
        self.loop = None
        self.host = host
        self.port = port
        self.processor = processor

    def start(self):
        self.t = Thread(target=self.serve_forever)
        self.t.start()
        logging.info("WEBSOCKET STARTED IN " + self.host + ":" + str(self.port))

    def stop(self):
        # Thread._stop() is private and unusable in Python 3; stop the event
        # loop driving serve_forever() and wait for the thread to finish.
        if self.loop is not None:
            self.loop.call_soon_threadsafe(self.loop.stop)
        if self.t is not None:
            self.t.join()

    def serve_forever(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        start_server = websockets.serve(self.handler, self.host, self.port)
        self.loop.run_until_complete(start_server)
        self.loop.run_forever()

    async def handler(self, websocket, path):
        while True:
            data = await websocket.recv()
            logging.info("WEBSOCKET RECEIVED " + data)
            jsonData = json.loads(data)
            token = jsonData["token"]
            if token in self.sockets:
                # close() is a coroutine, so it must be awaited.
                await self.sockets[token].close()
            self.sockets[token] = websocket
            self.processor.checkPermissions(token)

    async def send(self, data, token):
        socket = self.sockets.get(token)
        if socket:
            await socket.send(data)

    def getUsers(self):
        return self.sockets.keys()

    def delToken(self, token):
        if token in self.sockets:
            socket = self.sockets.pop(token)
            # close() is a coroutine; schedule it on the server's event loop.
            if self.loop is not None:
                asyncio.run_coroutine_threadsafe(socket.close(), self.loop)
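
One TTLCache property worth keeping in mind for the class above: entries evicted by the TTL are dropped silently, so an expired websocket is never close()d by the cache itself. A small sketch of that behavior; DummySocket is purely illustrative.

import time

from cachetools import TTLCache

class DummySocket:
    def close(self):
        print("closed")

sockets = TTLCache(maxsize=10, ttl=1)
sockets["token-a"] = DummySocket()
print(list(sockets.keys()))   # ['token-a']
time.sleep(1.5)
print(list(sockets.keys()))   # [] -- the expired entry was dropped, close() never ran
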
Example #3
class Shiba:

    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)

        public_key = config.flickrPublic
        secret_key = config.flickrSecret
        flickr_api.set_keys(public_key, secret_key)

    @commands.command(pass_context=True, no_pm=True)
    async def shiba(self, ctx, *args):
        """: !shiba             | post random Shibas"""

        shibas = flickr_api.Photo.search(per_page=500, tags='shiba dog')
        random.shuffle(shibas)
        sizes = ['Large', 'Medium', 'Small', 'Original']

        while shibas[0]['id'] in self.cache:
            random.shuffle(shibas)

        photo_id = shibas[0]['id']
        self.cache[photo_id] = photo_id

        for size in sizes:
            try:
                r = self.session.get(shibas[0].getPhotoUrl(size))
                strainer = SoupStrainer('div', attrs={'id': 'allsizes-photo'})
                soup = BeautifulSoup(r.content, "html.parser", parse_only=strainer)

                file_url = soup.find('img')['src']

                embed = Embed()
                embed.set_author(name="Shiba", icon_url="http://i1.kym-cdn.com/entries/icons/facebook/000/013/564/aP2dv.jpg")
                embed.set_image(url=file_url)
                await self.bot.send_message(ctx.message.channel, embed=embed)
                return

            except Exception:
                # Try the next size if this one fails to download or parse.
                continue
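
The cache in the cog above acts as a "recently posted" set keyed by photo id, with entries forgotten after five minutes. A sketch of that de-duplication idea with made-up ids; filtering the candidates first also avoids re-shuffling forever when everything is already cached.

import random

from cachetools import TTLCache

recently_posted = TTLCache(maxsize=500, ttl=300)   # forget ids after 5 minutes

def pick_fresh(candidate_ids):
    fresh = [c for c in candidate_ids if c not in recently_posted]
    if not fresh:
        return None                  # everything was posted within the window
    choice = random.choice(fresh)
    recently_posted[choice] = True   # remember it for the next 5 minutes
    return choice

print(pick_fresh(["p1", "p2", "p3"]))
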
Example #4
class OrgAPI(object):
    def __init__(self,
                 symbol,
                 admin_mode=False,
                 url=DEFAULT_RSI_URL,
                 endpoint='/orgs',
                 members_endpoint='/api/orgs/getOrgMembers',
                 cache_ttl=DEFAULT_CACHE_TTL,
                 session=None,
                 username='',
                 password=''):
        self.session = session
        self.symbol = symbol
        self.url = url.rstrip('/')
        self.endpoint = endpoint
        self.members_endpoint = members_endpoint
        self.admin_mode = admin_mode

        self.org_url = "{}/{}/{}".format(self.url, self.endpoint.lstrip('/'),
                                         symbol)
        self.members_api = "{}/{}".format(self.url,
                                          self.members_endpoint.lstrip('/'))
        self._ttlcache = TTLCache(maxsize=1, ttl=cache_ttl)

        if self.session is None:
            self.session = RSISession(url=url)
            if username and password:
                self.session.authenticate(username, password)

        # pull the org details up front; this raises if the org is not found (HTTP 404)
        self._update_details()

    def clear_cache(self):
        """ Resets the cache """
        self._ttlcache.clear()

    def _cache(self, key, update_func, *args, **kwargs):
        if key not in self._ttlcache:
            self._ttlcache[key] = update_func(*args, **kwargs)
        return self._ttlcache[key]

    def _update_members(self, search):
        members = []

        params = {'symbol': self.symbol, 'search': search, 'page': 1}

        if self.admin_mode:
            params['admin_mode'] = 1

        # this just gets us going
        totalsize = 1

        while len(members) < totalsize:
            r = self.session.post(self.members_api, data=params)

            if r.status_code == 200:
                r = r.json()
                if r is None:
                    continue

                if 'data' in r and r['data'] and 'totalrows' in r['data']:
                    totalsize = int(r['data']['totalrows'])

                if r['success'] == 1:
                    apisoup = BeautifulSoup(r['data']['html'], features='lxml')
                    for member in apisoup.select('.member-item'):
                        members.append({
                            'name': member.select_one('.name').text,
                            'handle': member.select_one('.nick').text,
                            'avatar': '{}{}'.format(
                                self.url, member.select_one('img').attrs['src']),
                            'affiliate': member.select_one('.title').text == 'Affiliate',
                            'rank': member.select_one('.rank').text,
                            'roles': [_.text for _ in member.select('.rolelist .role')],
                            'url': '{}{}'.format(
                                self.url, member.select_one('a.membercard').attrs['href']),
                        })

                        if self.admin_mode:
                            members[-1].update({
                                'id': member.attrs.get('data-member-id', ''),
                                'last_online': member.select_one('.frontinfo .lastonline').text,
                                'visibility': member.select_one('.frontinfo .visibility').text,
                            })

                    params['page'] = params['page'] + 1
                else:
                    raise ValueError(
                        'Received error fetching Org members: {}'.format(r))
            else:
                raise Exception(
                    'Received error fetching Org members: {}'.format(
                        r.status_code))
            time.sleep(0.5)
        return members

    def _update_details(self):
        r = self.session.get(self.org_url)
        data = {}
        r.raise_for_status()

        orgsoup = BeautifulSoup(r.text, features='lxml')
        data['banner'] = '{}{}'.format(
            self.url,
            orgsoup.select_one('.banner img')['src'])
        data['logo'] = '{}{}'.format(self.url,
                                     orgsoup.select_one('.logo img')['src'])
        data['name'], data['symbol'] = orgsoup.select_one(
            '.inner h1').text.split(' / ')
        data['model'] = orgsoup.select_one('.inner .tags .model').text
        data['commitment'] = orgsoup.select_one(
            '.inner .tags .commitment').text
        data['primary_focus'] = orgsoup.select_one(
            '.inner .focus .primary img')['alt']
        data['secondary_focus'] = orgsoup.select_one(
            '.inner .focus .secondary img')['alt']
        data['join_us'] = orgsoup.select_one('.join-us .body').text.strip()
        return data

    def search(self, handle, score_cutoff=80, limit=None):
        """
        Return members that match the given handle using fuzzy matching.

        :param handle: Handle to match
        :param score_cutoff: minimum matching score to return
        :param limit: limit the number of matches found
        :return: List of matched results in the form of [(dict, int)] where dict is the member data and int is the
                 matching confidence
        """
        choices = {i: _['name'] for i, _ in enumerate(self.members)}
        return [(self.members[_[2]], _[1]) for _ in process.extractBests(
            handle, choices, score_cutoff=score_cutoff, limit=limit)]

    def search_one(self, handle):
        """
        Return the first member that matches the given handle using fuzzy matching, or None

        :param handle: Handle to match
        :return: The best matching member, or None
        """
        choices = self.search(handle, limit=1)
        if choices:
            return choices[0][0]
        return None

    @property
    def members(self):
        return self._cache('members', self._update_members, search='')

    @property
    def details(self):
        return self._cache('details', self._update_details)

    @property
    def banner(self):
        return self.details['banner']

    @property
    def logo(self):
        return self.details['logo']

    @property
    def name(self):
        return self.details['name']

    @property
    def model(self):
        return self.details['model']

    @property
    def commitment(self):
        return self.details['commitment']

    @property
    def primary_focus(self):
        return self.details['primary_focus']

    @property
    def secondary_focus(self):
        return self.details['secondary_focus']

    @property
    def spectrum_url(self):
        return '{}/spectrum/community/{}'.format(self.url, self.symbol)

    @property
    def join_us(self):
        return self.details['join_us']
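
The _cache()/property pattern above is a hand-rolled lazy memoizer on top of a small TTLCache; cachetools also ships a cached decorator that expresses the same idea. A sketch with a stand-in fetch_details() function in place of the real HTTP call.

from cachetools import TTLCache, cached

@cached(cache=TTLCache(maxsize=1, ttl=300))
def fetch_details():
    print("expensive call")          # runs at most once per 5 minutes
    return {"name": "Example Org"}

print(fetch_details())               # triggers the call
print(fetch_details())               # served from the cache
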
Example #5
class Bot:
    """The twitter bot."""

    # Consts
    CALL_TEXT = "{} second encrypted call at {}"
    HASHTAGS = "#SeattleEncryptedComms"
    TWEET_PADDING = 20
    BASE_URL = "https://api.openmhz.com/kcers1b/calls/newer?time={}&filter-type=talkgroup&filter-code=44912,45040,45112,45072,45136"

    # DEBUG URL TO GET A LOT OF API RESPONSES
    # BASE_URL = "https://api.openmhz.com/kcers1b/calls/newer?time={}"

    def __init__(self) -> None:
        """Initializes the class."""
        self.callThreshold = int(os.getenv("CALL_THRESHOLD", 1))
        self.debug = os.getenv("DEBUG", "true").lower() == "true"
        self.reportLatency = os.getenv("REPORT_LATENCY",
                                       "false").lower() == "true"
        self.window_minutes = int(os.getenv("WINDOW_M", 5))
        self.timezone = pytz.timezone(os.getenv("TIMEZONE", "US/Pacific"))
        # The actual look back is the length of this lookback + lag compensation. For example: 300+45=345 seconds
        self.lookback = int(os.getenv("LOOKBACK_S", 300))

        self.cachedTweet: int = None
        self.cachedTime: datetime = None
        self.cache = TTLCache(maxsize=100, ttl=self.lookback)
        self.scraper = Scraper.Instance(self.BASE_URL, self.lookback)

        self.latency = [timedelta(seconds=0)]

        if not self.debug:
            # Does not need to be saved for later.
            # If the keys aren't in env this will still run.
            auth = tweepy.OAuthHandler(os.getenv("CONSUMER_KEY", ""),
                                       os.getenv("CONSUMER_SECRET", ""))
            auth.set_access_token(os.getenv("ACCESS_TOKEN_KEY", ""),
                                  os.getenv("ACCESS_TOKEN_SECRET", ""))
            self.api = tweepy.API(auth)
            # Test the authentication. This will gracefully fail if the keys aren't present.
            try:
                self.api.rate_limit_status()
            except TweepError as e:
                if e.api_code == 215:
                    log.error("No keys or bad keys")
                else:
                    log.error("Other API error: {}".format(e))
                exit(1)

        self.interval = Set.Interval(30, self._check)

    def _kill(self) -> None:
        """This kills the c̶r̶a̶b̶  bot."""
        self.interval.cancel()
        exit(0)

    def _getUniqueCalls(self, calls: list) -> list:
        """Filters the return from the scraper to only tweet unique calls.
        Works by checking if the cache already has that call ID.
        Args:
            calls (list): The complete list of calls scraped.
        Returns:
            list: A filtered list of calls.
        """
        res: List[dict] = []
        for call in calls:
            # Skip calls whose ID is already cached.
            if call["_id"] in self.cache:
                continue
            # Otherwise cache the ID and keep the call.
            # Might want to actually store something? Who knows.
            self.cache[call["_id"]] = 0
            res.append(call)
        return res

    def _check(self) -> None:
        """Checks the API and sends a tweet if needed."""
        try:
            log.info(f"Checking!: {datetime.now()}")
            json = None
            try:
                json = self.scraper.getJSON()
                calls = self._getUniqueCalls(json["calls"])
                log.info(f"Found {len(calls)} calls.")
                if len(calls) > 0:
                    self._postTweet(calls)
            except TypeError as e:
                if json is None:
                    # We already have an error message from the scraper
                    return
                log.exception(e)
                return
        except KeyboardInterrupt as e:
            # Literally impossible to hit which might be an issue? Catching keyboard interrupt could happen in its own thread or something but that sounds complicated 👉👈
            self._kill()

        if self.reportLatency:
            # sum() of timedeltas needs a timedelta start value, and must not shadow the built-in sum.
            total = sum(self.latency, timedelta()).total_seconds()
            avg = round(total / len(self.latency), 3)
            log.info(f"Average latency for the last 100 calls: {avg} seconds")

    def _postTweet(self, calls: list) -> None:
        """Posts a tweet.
        Args:
            calls (list): The call objects to post about.
        """

        # Filter to make sure that calls are actually recent. There can be a weird behavior
        # of the API returning multiple hours old calls all at once. Also filters for calls
        # under the length threshold.
        filteredCalls: List[dict] = []
        for call in calls:
            diff = datetime.now(pytz.utc) - datetime.strptime(
                call["time"], "%Y-%m-%dT%H:%M:%S.000%z")
            # Skip anything more than 30 minutes (1800 seconds) old.
            if abs(diff.total_seconds()) < 1.8e3:
                if call["len"] < self.callThreshold:
                    log.debug(
                        f"Call of size {call['len']} below threshold ({self.callThreshold})"
                    )
                    continue
                filteredCalls.append(call)

                if self.reportLatency:
                    # Store latency
                    self.latency.append(diff)
                    if len(self.latency) > 100:
                        self.latency.pop(0)

        if len(filteredCalls) == 0:
            # If there's nothing to post, simply leave
            return

        msgs = self._generateTweets(filteredCalls)

        if self.debug:
            msg = " | ".join(msgs)
            log.debug(f"Would have posted: {msg}")
            return

        # Check for a cached tweet, then check if the last tweet was less than the window ago. If the window has expired dereference the cached tweet.
        if (self.cachedTime is not None
                and self.cachedTime + timedelta(minutes=self.window_minutes) <=
                datetime.now()):
            self.cachedTweet = None

        try:
            if self.cachedTweet is not None:
                for msg in msgs:
                    # Every time it posts the new ID gets stored so this works
                    self.cachedTweet = self.api.update_status(
                        msg, self.cachedTweet).id
            else:
                for index, msg in enumerate(msgs):
                    if index == 0:
                        # Since there isn't a cached tweet yet we have to send a non-reply first
                        self.cachedTweet = self.api.update_status(msg).id
                    else:
                        self.cachedTweet = self.api.update_status(
                            msg, self.cachedTweet).id
            self.cachedTime = datetime.now()
        except tweepy.TweepError as e:
            log.exception(e)

    def _timeString(self, call: dict) -> str:
        """Generates a time code string for a call.
        Args:
            call (dict): The call to get time from.
        Returns:
            str: A timestamp string in I:M:S am/pm format.
        """
        # Get time from the call.
        date = datetime.strptime(call["time"], "%Y-%m-%dT%H:%M:%S.000%z")
        # F**k I hate how computer time works
        localized = date.replace(tzinfo=pytz.utc).astimezone(self.timezone)
        normalized = self.timezone.normalize(localized)
        return normalized.strftime("%#I:%M:%S %p")

    def _chunk(self, callStrings: list) -> list:
        """Chunks tweets into an acceptable length.

        Chunking. Shamelessly stolen from `SeattleDSA/signal_scanner_bot/twitter.py` :)

        Args:
            callStrings (list): List of strings derived from calls.

        Returns:
            list: A list of tweet strings to post
        """
        tweetList: List[str] = []
        baseIndex = 0

        # Instead of splitting on words I want to split along call lines.
        subTweet: str = ""
        for index in range(len(callStrings)):
            if len(tweetList) == 0:
                subTweet = (", ".join(callStrings[baseIndex:index]) + " ... " +
                            self.HASHTAGS)
            elif index < len(callStrings):
                subTweet = ", ".join(callStrings[baseIndex:index]) + " ..."
            elif index == len(callStrings):
                subTweet = ", ".join(callStrings[baseIndex:index])

            if len(subTweet) > 280 - self.TWEET_PADDING:
                lastIndex = index - 1
                tweetList.append(", ".join(callStrings[baseIndex:lastIndex]) +
                                 " ...")
                baseIndex = lastIndex

        tweetList.append(", ".join(callStrings[baseIndex:]))
        listLength = len(tweetList)
        for index in range(len(tweetList)):
            if index == 0:
                tweetList[
                    index] += f" {self.HASHTAGS} {index + 1}/{listLength}"
            else:
                tweetList[index] += f" {index + 1}/{listLength}"

        return tweetList

    def _generateTweets(self, calls: list) -> list:
        """Generates tweet messages.
        Args:
            calls (list): The calls to tweet about.
        Returns:
            list: The tweet messages, hopefully right around the character limit.
        """
        callStrings: List[str] = []

        # First, take all of the calls and turn them into strings.
        for call in calls:
            callStrings.append(
                self.CALL_TEXT.format(
                    call["len"],
                    self._timeString(call),
                ))

        tweet = ", ".join(callStrings) + " " + self.HASHTAGS
        # If we don't have to chunk we can just leave.
        if len(tweet) <= 280:
            return [tweet]
        else:
            tweetList = self._chunk(callStrings)

        return tweetList
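
A sketch of the env-driven cache setup in __init__ above: the look-back window is read from the environment and cast to int before being handed to TTLCache as the ttl. The defaults mirror the ones in the class; the call id is made up.

import os

from cachetools import TTLCache

lookback = int(os.getenv("LOOKBACK_S", 300))   # seconds a call id stays cached
cache = TTLCache(maxsize=100, ttl=lookback)

cache["call-id"] = 0
print("call-id" in cache)   # True until `lookback` seconds have passed
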
Example #6
class WSNewBlocksWatcher(BaseWatcher):

    MESSAGE_TIMEOUT = 30.0
    PING_TIMEOUT = 10.0

    def __init__(self, w3: Web3, websocket_url):
        super().__init__(w3)
        self._network_on = False
        self._nonce: int = 0
        self._current_block_number: int = -1
        self._websocket_url = websocket_url
        self._node_address = None
        self._client: Optional[websockets.WebSocketClientProtocol] = None
        self._fetch_new_blocks_task: Optional[asyncio.Task] = None
        self._block_cache = TTLCache(maxsize=10, ttl=120)

    _nbw_logger: Optional[HummingbotLogger] = None

    @classmethod
    def logger(cls) -> HummingbotLogger:
        if cls._nbw_logger is None:
            cls._nbw_logger = logging.getLogger(__name__)
        return cls._nbw_logger

    @property
    def block_number(self) -> int:
        return self._current_block_number

    @property
    def block_cache(self) -> Dict[HexBytes, AttributeDict]:
        cache_dict: Dict[HexBytes, AttributeDict] = dict([(key, self._block_cache[key])
                                                          for key in self._block_cache.keys()])
        return cache_dict

    async def start_network(self):
        if self._fetch_new_blocks_task is not None:
            await self.stop_network()
        else:
            try:
                self._current_block_number = await self.call_async(getattr, self._w3.eth, "blockNumber")
            except asyncio.CancelledError:
                raise
            except Exception:
                self.logger().network("Error fetching newest Ethereum block number.",
                                      app_warning_msg="Error fetching newest Ethereum block number. "
                                                      "Check Ethereum node connection",
                                      exc_info=True)
            await self.connect()
            await self.subscribe(["newHeads"])
            self._fetch_new_blocks_task: asyncio.Task = safe_ensure_future(self.fetch_new_blocks_loop())
            self._network_on = True

    async def stop_network(self):
        if self._fetch_new_blocks_task is not None:
            await self.disconnect()
            self._fetch_new_blocks_task.cancel()
            self._fetch_new_blocks_task = None
            self._network_on = False

    async def connect(self):
        try:
            self._client = await websockets.connect(uri=self._websocket_url)
            return self._client
        except Exception as e:
            self.logger().network(f"ERROR in connection: {e}")

    async def disconnect(self):
        try:
            await self._client.close()
        except Exception as e:
            self.logger().network(f"ERROR in disconnection: {e}")

    async def _send(self, emit_data) -> int:
        self._nonce += 1
        emit_data["id"] = self._nonce
        await self._client.send(ujson.dumps(emit_data))
        return self._nonce

    async def subscribe(self, params) -> bool:
        emit_data = {
            "method": "eth_subscribe",
            "params": params
        }
        nonce = await self._send(emit_data)
        raw_message = await self._client.recv()
        if raw_message is not None:
            resp = ujson.loads(raw_message)
            if resp.get("id", None) == nonce:
                self._node_address = resp.get("result")
                return True
        return False

    async def _messages(self) -> AsyncIterable[Any]:
        try:
            while True:
                try:
                    raw_msg_str: str = await asyncio.wait_for(self._client.recv(), self.MESSAGE_TIMEOUT)
                    yield raw_msg_str
                except asyncio.TimeoutError:
                    try:
                        pong_waiter = await self._client.ping()
                        await asyncio.wait_for(pong_waiter, timeout=self.PING_TIMEOUT)
                    except asyncio.TimeoutError:
                        raise
        except asyncio.TimeoutError:
            self.logger().warning("WebSocket ping timed out. Going to reconnect...")
            return
        except ConnectionClosed:
            return
        finally:
            await self.disconnect()

    async def fetch_new_blocks_loop(self):
        while True:
            try:
                async for raw_message in self._messages():
                    message_json = ujson.loads(raw_message) if raw_message is not None else None
                    if message_json.get("method", None) == "eth_subscription":
                        subscription_result_params = message_json.get("params", None)
                        incoming_block = subscription_result_params.get("result", None) \
                            if subscription_result_params is not None else None
                        if incoming_block is not None:
                            new_block: AttributeDict = await self.call_async(self._w3.eth.getBlock,
                                                                             incoming_block.get("hash"), True)
                            self._current_block_number = new_block.get("number")
                            self._block_cache[new_block.get("hash")] = new_block
                            self.trigger_event(NewBlocksWatcherEvent.NewBlocks, [new_block])
            except asyncio.TimeoutError:
                self.logger().network("Timed out fetching new block.", exc_info=True,
                                      app_warning_msg="Timed out fetching new block. "
                                                      "Check wallet network connection")
            except asyncio.CancelledError:
                raise
            except BlockNotFound:
                pass
            except Exception as e:
                self.logger().network(f"Error fetching new block: {e}", exc_info=True,
                                      app_warning_msg="Error fetching new block. "
                                                      "Check wallet network connection")
                await asyncio.sleep(30.0)

    async def get_timestamp_for_block(self, block_hash: HexBytes, max_tries: Optional[int] = 10) -> int:
        counter = 0
        block: Optional[AttributeDict] = None
        if block_hash in self._block_cache:
            block = self._block_cache.get(block_hash)
        else:
            while block is None:
                if counter == max_tries:
                    raise ValueError(f"Block hash {block_hash.hex()} does not exist.")
                counter += 1
                block = self._block_cache.get(block_hash)
                await asyncio.sleep(0.5)
        return block.get("timestamp")
Example #7
class SQLServer(AgentCheck):
    __NAMESPACE__ = 'sqlserver'

    def __init__(self, name, init_config, instances):
        super(SQLServer, self).__init__(name, init_config, instances)

        self._resolved_hostname = None
        self._agent_hostname = None
        self.connection = None
        self.failed_connections = {}
        self.instance_metrics = []
        self.instance_per_type_metrics = defaultdict(set)
        self.do_check = True

        self.tags = self.instance.get("tags", [])
        self.reported_hostname = self.instance.get('reported_hostname')
        self.autodiscovery = is_affirmative(self.instance.get('database_autodiscovery'))
        self.autodiscovery_include = self.instance.get('autodiscovery_include', ['.*'])
        self.autodiscovery_exclude = self.instance.get('autodiscovery_exclude', [])
        self.autodiscovery_db_service_check = is_affirmative(self.instance.get('autodiscovery_db_service_check', True))
        self.min_collection_interval = self.instance.get('min_collection_interval', 15)
        self._compile_patterns()
        self.autodiscovery_interval = self.instance.get('autodiscovery_interval', DEFAULT_AUTODISCOVERY_INTERVAL)
        self.databases = set()
        self.ad_last_check = 0

        self.proc = self.instance.get('stored_procedure')
        self.proc_type_mapping = {'gauge': self.gauge, 'rate': self.rate, 'histogram': self.histogram}
        self.custom_metrics = init_config.get('custom_metrics', [])

        # DBM
        self.dbm_enabled = self.instance.get('dbm', False)
        self.statement_metrics_config = self.instance.get('query_metrics', {}) or {}
        self.statement_metrics = SqlserverStatementMetrics(self)
        self.activity_config = self.instance.get('query_activity', {}) or {}
        self.activity = SqlserverActivity(self)
        self.cloud_metadata = {}
        aws = self.instance.get('aws', {})
        gcp = self.instance.get('gcp', {})
        azure = self.instance.get('azure', {})
        if aws:
            self.cloud_metadata.update({'aws': aws})
        if gcp:
            self.cloud_metadata.update({'gcp': gcp})
        if azure:
            self.cloud_metadata.update({'azure': azure})
        obfuscator_options_config = self.instance.get('obfuscator_options', {}) or {}
        self.obfuscator_options = to_native_string(
            json.dumps(
                {
                    # Valid values for this can be found at
                    # https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/database.md#connection-level-attributes
                    'dbms': 'mssql',
                    'replace_digits': is_affirmative(
                        obfuscator_options_config.get(
                            'replace_digits',
                            obfuscator_options_config.get('quantize_sql_tables', False),
                        )
                    ),
                    'keep_sql_alias': is_affirmative(obfuscator_options_config.get('keep_sql_alias', True)),
                    'return_json_metadata': is_affirmative(obfuscator_options_config.get('collect_metadata', True)),
                    'table_names': is_affirmative(obfuscator_options_config.get('collect_tables', True)),
                    'collect_commands': is_affirmative(obfuscator_options_config.get('collect_commands', True)),
                    'collect_comments': is_affirmative(obfuscator_options_config.get('collect_comments', True)),
                }
            )
        )

        self.static_info_cache = TTLCache(
            maxsize=100,
            # cache these for a full day
            ttl=60 * 60 * 24,
        )

        # Query declarations
        check_queries = []
        if is_affirmative(self.instance.get('include_ao_metrics', False)):
            check_queries.extend(
                [
                    QUERY_AO_AVAILABILITY_GROUPS,
                    QUERY_AO_FAILOVER_CLUSTER,
                    QUERY_AO_FAILOVER_CLUSTER_MEMBER,
                ]
            )
        if is_affirmative(self.instance.get('include_fci_metrics', False)):
            check_queries.extend([QUERY_FAILOVER_CLUSTER_INSTANCE])
        self._check_queries = self._new_query_executor(check_queries)
        self.check_initializations.append(self._check_queries.compile_queries)

        self.server_state_queries = self._new_query_executor([QUERY_SERVER_STATIC_INFO])
        self.check_initializations.append(self.server_state_queries.compile_queries)

        # use QueryManager to process custom queries
        self._query_manager = QueryManager(
            self, self.execute_query_raw, tags=self.tags, hostname=self.resolved_hostname
        )

        self._dynamic_queries = None

        self.check_initializations.append(self.config_checks)
        self.check_initializations.append(self._query_manager.compile_queries)
        self.check_initializations.append(self.initialize_connection)

    def cancel(self):
        self.statement_metrics.cancel()
        self.activity.cancel()

    def config_checks(self):
        if self.autodiscovery and self.instance.get('database'):
            self.log.warning(
                'sqlserver `database_autodiscovery` and `database` options defined in same instance - '
                'autodiscovery will take precedence.'
            )
        if not self.autodiscovery and (self.autodiscovery_include or self.autodiscovery_exclude):
            self.log.warning(
                "Autodiscovery is disabled, autodiscovery_include and autodiscovery_exclude will be ignored"
            )

    def split_sqlserver_host_port(self, host):
        """
        Splits the host & port out of the provided SQL Server host connection string, returning (host, port).
        """
        if not host:
            return host, None
        host_split = [s.strip() for s in host.split(',')]
        if len(host_split) == 1:
            return host_split[0], None
        if len(host_split) == 2:
            return host_split
        # else len > 2
        s_host, s_port = host_split[0:2]
        self.log.warning(
            "invalid sqlserver host string has more than one comma: %s. using only 1st two items: host:%s, port:%s",
            host,
            s_host,
            s_port,
        )
        return s_host, s_port

    def _new_query_executor(self, queries):
        return QueryExecutor(
            self.execute_query_raw,
            self,
            queries=queries,
            tags=self.tags,
            hostname=self.resolved_hostname,
        )

    @property
    def resolved_hostname(self):
        if self._resolved_hostname is None:
            if self.reported_hostname:
                self._resolved_hostname = self.reported_hostname
            elif self.dbm_enabled:
                host, port = self.split_sqlserver_host_port(self.instance.get('host'))
                self._resolved_hostname = resolve_db_host(host)
            else:
                self._resolved_hostname = self.agent_hostname
        return self._resolved_hostname

    def load_static_information(self):
        expected_keys = {STATIC_INFO_VERSION, STATIC_INFO_MAJOR_VERSION, STATIC_INFO_ENGINE_EDITION}
        missing_keys = expected_keys - set(self.static_info_cache.keys())
        if missing_keys:
            with self.connection.open_managed_default_connection():
                with self.connection.get_managed_cursor() as cursor:
                    if STATIC_INFO_VERSION not in self.static_info_cache:
                        cursor.execute("select @@version")
                        results = cursor.fetchall()
                        if results and len(results) > 0 and len(results[0]) > 0 and results[0][0]:
                            version = results[0][0]
                            self.static_info_cache[STATIC_INFO_VERSION] = version
                            self.static_info_cache[STATIC_INFO_MAJOR_VERSION] = parse_sqlserver_major_version(version)
                            if not self.static_info_cache[STATIC_INFO_MAJOR_VERSION]:
                                self.log.warning("failed to parse SQL Server major version from version: %s", version)
                        else:
                            self.log.warning("failed to load version static information due to empty results")
                    if STATIC_INFO_ENGINE_EDITION not in self.static_info_cache:
                        cursor.execute("SELECT CAST(ServerProperty('EngineEdition') AS INT) AS Edition")
                        result = cursor.fetchone()
                        if result:
                            # fetchone() returns a row; store only the scalar edition value
                            self.static_info_cache[STATIC_INFO_ENGINE_EDITION] = result[0]
                        else:
                            self.log.warning("failed to load version static information due to empty results")

    def debug_tags(self):
        return self.tags + ['agent_hostname:{}'.format(self.agent_hostname)]

    def debug_stats_kwargs(self, tags=None):
        tags = tags if tags else []
        return {
            "tags": self.debug_tags() + tags,
            "hostname": self.resolved_hostname,
            "raw": True,
        }

    @property
    def agent_hostname(self):
        # type: () -> str
        if self._agent_hostname is None:
            self._agent_hostname = datadog_agent.get_hostname()
        return self._agent_hostname

    def initialize_connection(self):
        self.connection = Connection(self.init_config, self.instance, self.handle_service_check)

        # Pre-process the list of metrics to collect
        try:
            # check to see if the database exists before we try any connections to it
            db_exists, context = self.connection.check_database()

            if db_exists:
                if self.instance.get('stored_procedure') is None:
                    with self.connection.open_managed_default_connection():
                        with self.connection.get_managed_cursor() as cursor:
                            self.autodiscover_databases(cursor)
                        self._make_metric_list_to_collect(self.custom_metrics)
            else:
                # How much do we care that the DB doesn't exist?
                ignore = is_affirmative(self.instance.get("ignore_missing_database", False))
                if ignore:
                    # not much : we expect it. leave checks disabled
                    self.do_check = False
                    self.log.warning("Database %s does not exist. Disabling checks for this instance.", context)
                else:
                    # yes we do. Keep trying
                    msg = "Database {} does not exist. Please resolve invalid database and restart agent".format(
                        context
                    )
                    raise ConfigurationError(msg)

        except SQLConnectionError as e:
            self.log.exception("Error connecting to database: %s", e)
        except ConfigurationError:
            raise
        except Exception as e:
            self.log.exception("Initialization exception %s", e)

    def handle_service_check(self, status, host, database, message=None, is_default=True):
        custom_tags = self.instance.get("tags", [])
        disable_generic_tags = self.instance.get('disable_generic_tags', False)
        if custom_tags is None:
            custom_tags = []
        if disable_generic_tags:
            service_check_tags = ['sqlserver_host:{}'.format(host), 'db:{}'.format(database)]
        else:
            service_check_tags = ['host:{}'.format(host), 'sqlserver_host:{}'.format(host), 'db:{}'.format(database)]
        service_check_tags.extend(custom_tags)
        service_check_tags = list(set(service_check_tags))

        if status is AgentCheck.OK:
            message = None

        if is_default:
            self.service_check(SERVICE_CHECK_NAME, status, tags=service_check_tags, message=message, raw=True)
        if self.autodiscovery and self.autodiscovery_db_service_check:
            self.service_check(DATABASE_SERVICE_CHECK_NAME, status, tags=service_check_tags, message=message, raw=True)

    def _compile_patterns(self):
        self._include_patterns = self._compile_valid_patterns(self.autodiscovery_include)
        self._exclude_patterns = self._compile_valid_patterns(self.autodiscovery_exclude)

    def _compile_valid_patterns(self, patterns):
        valid_patterns = []

        for pattern in patterns:
            # Ignore empty patterns as they match everything
            if not pattern:
                continue

            try:
                re.compile(pattern, re.IGNORECASE)
            except Exception:
                self.log.warning('%s is not a valid regular expression and will be ignored', pattern)
            else:
                valid_patterns.append(pattern)

        if valid_patterns:
            return re.compile('|'.join(valid_patterns), re.IGNORECASE)
        else:
            # create unmatchable regex - https://stackoverflow.com/a/1845097/2157429
            return re.compile(r'(?!x)x')

    def autodiscover_databases(self, cursor):
        if not self.autodiscovery:
            return False

        now = time.time()
        if now - self.ad_last_check > self.autodiscovery_interval:
            self.log.info('Performing database autodiscovery')
            cursor.execute(AUTODISCOVERY_QUERY)
            all_dbs = set(row.name for row in cursor.fetchall())
            excluded_dbs = set([d for d in all_dbs if self._exclude_patterns.match(d)])
            included_dbs = set([d for d in all_dbs if self._include_patterns.match(d)])

            self.log.debug(
                'Autodiscovered databases: %s, excluding: %s, including: %s', all_dbs, excluded_dbs, included_dbs
            )

            # keep included dbs but remove any that were explicitly excluded
            filtered_dbs = all_dbs.intersection(included_dbs) - excluded_dbs

            self.log.debug('Resulting filtered databases: %s', filtered_dbs)
            self.ad_last_check = now

            if filtered_dbs != self.databases:
                self.log.debug('Databases updated from previous autodiscovery check.')
                self.databases = filtered_dbs
                return True
        return False

    def _make_metric_list_to_collect(self, custom_metrics):
        """
        Store the list of metrics to collect by instance_key.
        Will also create and cache cursors to query the db.
        """

        metrics_to_collect = []
        tags = self.instance.get('tags', [])

        # Load instance-level (previously Performance) metrics
        # If several check instances are querying the same server host, it can be wise to turn these off
        # to avoid sending duplicate metrics
        if is_affirmative(self.instance.get('include_instance_metrics', True)):
            common_metrics = INSTANCE_METRICS
            if not self.dbm_enabled:
                common_metrics = common_metrics + DBM_MIGRATED_METRICS

            self._add_performance_counters(
                chain(common_metrics, INSTANCE_METRICS_TOTAL), metrics_to_collect, tags, db=None
            )

        # populated through autodiscovery
        if self.databases:
            for db in self.databases:
                self._add_performance_counters(INSTANCE_METRICS_TOTAL, metrics_to_collect, tags, db=db)

        # Load database statistics
        for name, table, column in DATABASE_METRICS:
            # include database as a filter option
            db_names = self.databases or [self.instance.get('database', self.connection.DEFAULT_DATABASE)]
            for db_name in db_names:
                cfg = {'name': name, 'table': table, 'column': column, 'instance_name': db_name, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load AlwaysOn metrics
        if is_affirmative(self.instance.get('include_ao_metrics', False)):
            for name, table, column in AO_METRICS + AO_METRICS_PRIMARY + AO_METRICS_SECONDARY:
                db_name = 'master'
                cfg = {
                    'name': name,
                    'table': table,
                    'column': column,
                    'instance_name': db_name,
                    'tags': tags,
                    'ao_database': self.instance.get('ao_database', None),
                    'availability_group': self.instance.get('availability_group', None),
                    'only_emit_local': is_affirmative(self.instance.get('only_emit_local', False)),
                }
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load metrics from scheduler and task tables, if enabled
        if is_affirmative(self.instance.get('include_task_scheduler_metrics', False)):
            for name, table, column in TASK_SCHEDULER_METRICS:
                cfg = {'name': name, 'table': table, 'column': column, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load sys.master_files metrics
        if is_affirmative(self.instance.get('include_master_files_metrics', False)):
            for name, table, column in DATABASE_MASTER_FILES:
                cfg = {'name': name, 'table': table, 'column': column, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load DB Fragmentation metrics
        if is_affirmative(self.instance.get('include_db_fragmentation_metrics', False)):
            db_fragmentation_object_names = self.instance.get('db_fragmentation_object_names', [])
            db_names = self.databases or [self.instance.get('database', self.connection.DEFAULT_DATABASE)]

            if not db_fragmentation_object_names:
                self.log.debug(
                    "No fragmentation object names specified, will return fragmentation metrics for all "
                    "object_ids of current database(s): %s",
                    db_names,
                )

            for db_name in db_names:
                for name, table, column in DATABASE_FRAGMENTATION_METRICS:
                    cfg = {
                        'name': name,
                        'table': table,
                        'column': column,
                        'instance_name': db_name,
                        'tags': tags,
                        'db_fragmentation_object_names': db_fragmentation_object_names,
                    }
                    metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load any custom metrics from conf.d/sqlserver.yaml
        for cfg in custom_metrics:
            sql_type = None
            base_name = None

            custom_tags = tags + cfg.get('tags', [])
            cfg['tags'] = custom_tags

            db_table = cfg.get('table', DEFAULT_PERFORMANCE_TABLE)
            if db_table not in VALID_TABLES:
                self.log.error('%s has an invalid table name: %s', cfg['name'], db_table)
                continue

            if cfg.get('database', None) and cfg.get('database') != self.instance.get('database'):
                self.log.debug(
                    'Skipping custom metric %s for database %s, check instance configured for database %s',
                    cfg['name'],
                    cfg.get('database'),
                    self.instance.get('database'),
                )
                continue

            if db_table == DEFAULT_PERFORMANCE_TABLE:
                user_type = cfg.get('type')
                if user_type is not None and user_type not in VALID_METRIC_TYPES:
                    self.log.error('%s has an invalid metric type: %s', cfg['name'], user_type)
                sql_type = None
                try:
                    if user_type is None:
                        sql_type, base_name = self.get_sql_type(cfg['counter_name'])
                except Exception:
                    self.log.warning("Can't load the metric %s, ignoring", cfg['name'], exc_info=True)
                    continue

                metrics_to_collect.append(
                    self.typed_metric(
                        cfg_inst=cfg, table=db_table, base_name=base_name, user_type=user_type, sql_type=sql_type
                    )
                )

            else:
                for column in cfg['columns']:
                    metrics_to_collect.append(
                        self.typed_metric(
                            cfg_inst=cfg, table=db_table, base_name=base_name, sql_type=sql_type, column=column
                        )
                    )

        self.instance_metrics = metrics_to_collect
        self.log.debug("metrics to collect %s", metrics_to_collect)

        # create an organized grouping of metric names to their metric classes
        for m in metrics_to_collect:
            cls = m.__class__.__name__
            name = m.sql_name or m.column
            self.log.debug("Adding metric class %s named %s", cls, name)

            self.instance_per_type_metrics[cls].add(name)
            if m.base_name:
                self.instance_per_type_metrics[cls].add(m.base_name)

    def _add_performance_counters(self, metrics, metrics_to_collect, tags, db=None):
        if db is not None:
            tags = tags + ['database:{}'.format(db)]
        for name, counter_name, instance_name in metrics:
            try:
                sql_type, base_name = self.get_sql_type(counter_name)
                cfg = {
                    'name': name,
                    'counter_name': counter_name,
                    'instance_name': db or instance_name,
                    'tags': tags,
                }

                metrics_to_collect.append(
                    self.typed_metric(
                        cfg_inst=cfg, table=DEFAULT_PERFORMANCE_TABLE, base_name=base_name, sql_type=sql_type
                    )
                )
            except SQLConnectionError:
                raise
            except Exception:
                self.log.warning("Can't load the metric %s, ignoring", name, exc_info=True)
                continue

    def get_sql_type(self, counter_name):
        """
        Return the type of the performance counter so that we can report it to
        Datadog correctly.
        If the sql_type is one that needs a base (PERF_RAW_LARGE_FRACTION and
        PERF_AVERAGE_BULK), the name of the base counter will also be returned.
        """
        with self.connection.get_managed_cursor() as cursor:
            cursor.execute(COUNTER_TYPE_QUERY, (counter_name,))
            (sql_type,) = cursor.fetchone()
            if sql_type == PERF_LARGE_RAW_BASE:
                self.log.warning("Metric %s is of type Base and shouldn't be reported this way", counter_name)
            base_name = None
            if sql_type in [PERF_AVERAGE_BULK, PERF_RAW_LARGE_FRACTION]:
                # This is an ugly hack. For certain types of metric (PERF_RAW_LARGE_FRACTION
                # and PERF_AVERAGE_BULK), we need two metrics: the metric specified and
                # a base metric to get the ratio. There is no unique schema so we generate
                # the possible candidates and we look at which ones exist in the db.
                candidates = (
                    counter_name + " base",
                    counter_name.replace("(ms)", "base"),
                    counter_name.replace("Avg ", "") + " base",
                )
                try:
                    cursor.execute(BASE_NAME_QUERY, candidates)
                    base_name = cursor.fetchone().counter_name.strip()
                    self.log.debug("Got base metric: %s for metric: %s", base_name, counter_name)
                except Exception as e:
                    self.log.warning("Could not get counter_name of base for metric: %s", e)

        return sql_type, base_name

    def typed_metric(self, cfg_inst, table, base_name=None, user_type=None, sql_type=None, column=None):
        """
        Create the appropriate BaseSqlServerMetric object, each implementing its method to
        fetch the metrics properly.
        If a `type` was specified in the config, it is used to report the value
        directly fetched from SQLServer. Otherwise, it is decided based on the
        sql_type, according to microsoft's documentation.
        """
        if table == DEFAULT_PERFORMANCE_TABLE:
            metric_type_mapping = {
                PERF_COUNTER_BULK_COUNT: (self.rate, metrics.SqlSimpleMetric),
                PERF_COUNTER_LARGE_RAWCOUNT: (self.gauge, metrics.SqlSimpleMetric),
                PERF_LARGE_RAW_BASE: (self.gauge, metrics.SqlSimpleMetric),
                PERF_RAW_LARGE_FRACTION: (self.gauge, metrics.SqlFractionMetric),
                PERF_AVERAGE_BULK: (self.gauge, metrics.SqlIncrFractionMetric),
            }
            if user_type is not None:
                # user type overrides any other value
                metric_type = getattr(self, user_type)
                cls = metrics.SqlSimpleMetric

            else:
                metric_type, cls = metric_type_mapping[sql_type]
        else:
            # Lookup metrics classes by their associated table
            metric_type_str, cls = metrics.TABLE_MAPPING[table]
            metric_type = getattr(self, metric_type_str)

        cfg_inst['hostname'] = self.resolved_hostname

        return cls(cfg_inst, base_name, metric_type, column, self.log)

    def check(self, _):
        if self.do_check:
            self.load_static_information()
            if self.proc:
                self.do_stored_procedure_check()
            else:
                self.collect_metrics()
            if self.autodiscovery and self.autodiscovery_db_service_check:
                for db_name in self.databases:
                    if db_name != self.connection.DEFAULT_DATABASE:
                        try:
                            self.connection.check_database_conns(db_name)
                        except Exception as e:
                            # service_check errors on auto discovered databases should not abort the check
                            self.log.warning("failed service check for auto discovered database: %s", e)

            if self.dbm_enabled:
                self.statement_metrics.run_job_loop(self.tags)
                self.activity.run_job_loop(self.tags)

        else:
            self.log.debug("Skipping check")

    @property
    def dynamic_queries(self):
        """
        Initializes dynamic queries which depend on static information loaded from the database
        """
        if self._dynamic_queries:
            return self._dynamic_queries

        major_version = self.static_info_cache.get(STATIC_INFO_MAJOR_VERSION)
        if not major_version:
            self.log.warning("missing major_version, cannot initialize dynamic queries")
            return None

        queries = [get_query_file_stats(major_version)]
        self._dynamic_queries = self._new_query_executor(queries)
        self._dynamic_queries.compile_queries()
        self.log.debug("initialized dynamic queries")
        return self._dynamic_queries

    def collect_metrics(self):
        """Fetch the metrics from all of the associated database tables."""

        with self.connection.open_managed_default_connection():
            with self.connection.get_managed_cursor() as cursor:
                # run autodiscovery, and rebuild the metric list if it is empty (e.g. the server was down at check __init__).
                if self.autodiscover_databases(cursor) or not self.instance_metrics:
                    self._make_metric_list_to_collect(self.custom_metrics)

                instance_results = {}

                # Execute the `fetch_all` operations first to minimize the database calls
                for cls, metric_names in six.iteritems(self.instance_per_type_metrics):
                    if not metric_names:
                        instance_results[cls] = None, None
                    else:
                        try:
                            db_names = self.databases or [
                                self.instance.get('database', self.connection.DEFAULT_DATABASE)
                            ]
                            rows, cols = getattr(metrics, cls).fetch_all_values(
                                cursor, list(metric_names), self.log, databases=db_names
                            )
                        except Exception as e:
                            self.log.error("Error running `fetch_all` for metrics %s - skipping.  Error: %s", cls, e)
                            rows, cols = None, None

                        instance_results[cls] = rows, cols

                # Using the cached data, extract and report individual metrics
                for metric in self.instance_metrics:
                    if type(metric) is metrics.SqlIncrFractionMetric:
                        # special case, since it uses the same results as SqlFractionMetric
                        key = 'SqlFractionMetric'
                    else:
                        key = metric.__class__.__name__

                    if key not in instance_results:
                        self.log.warning("No %s metrics found, skipping", str(key))
                    else:
                        rows, cols = instance_results[key]
                        if rows is not None:
                            metric.fetch_metric(rows, cols)

            # Neither pyodbc nor adodbapi can read the results of a query when the
            # "number of rows affected" count is returned as part of the result set,
            # so we disable it for the entire connection. This matters mostly for
            # custom_queries and the stored_procedure feature.
            # https://docs.microsoft.com/en-us/sql/t-sql/statements/set-nocount-transact-sql
            with self.connection.get_managed_cursor() as cursor:
                cursor.execute("SET NOCOUNT ON")
            try:
                # Server state queries require VIEW SERVER STATE permissions, which some managed database
                # versions do not support.
                if self.static_info_cache.get(STATIC_INFO_ENGINE_EDITION) not in [
                    ENGINE_EDITION_SQL_DATABASE,
                ]:
                    self.server_state_queries.execute()

                self._check_queries.execute()
                if self.dynamic_queries:
                    self.dynamic_queries.execute()
                # reuse connection for any custom queries
                self._query_manager.execute()
            finally:
                with self.connection.get_managed_cursor() as cursor:
                    cursor.execute("SET NOCOUNT OFF")

    def execute_query_raw(self, query):
        with self.connection.get_managed_cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()

    def do_stored_procedure_check(self):
        """
        Fetch the metrics from the stored proc
        """

        proc = self.proc
        guardSql = self.instance.get('proc_only_if')
        custom_tags = self.instance.get("tags", [])

        if (guardSql and self.proc_check_guard(guardSql)) or not guardSql:
            self.connection.open_db_connections(self.connection.DEFAULT_DB_KEY)
            cursor = self.connection.get_cursor(self.connection.DEFAULT_DB_KEY)

            try:
                self.log.debug("Calling Stored Procedure : %s", proc)
                if self.connection.get_connector() == 'adodbapi':
                    cursor.callproc(proc)
                else:
                    # pyodbc does not support callproc; use execute instead.
                    # Reference: https://github.com/mkleehammer/pyodbc/wiki/Calling-Stored-Procedures
                    call_proc = '{{CALL {}}}'.format(proc)
                    cursor.execute(call_proc)

                rows = cursor.fetchall()
                self.log.debug("Row count (%s) : %s", proc, cursor.rowcount)

                for row in rows:
                    tags = [] if row.tags is None or row.tags == '' else row.tags.split(',')
                    tags.extend(custom_tags)

                    if row.type.lower() in self.proc_type_mapping:
                        self.proc_type_mapping[row.type](row.metric, row.value, tags, raw=True)
                    else:
                        self.log.warning(
                            '%s is not a recognised type from procedure %s, metric %s', row.type, proc, row.metric
                        )

            except Exception as e:
                self.log.warning("Could not call procedure %s: %s", proc, e)
                raise e

            self.connection.close_cursor(cursor)
            self.connection.close_db_connections(self.connection.DEFAULT_DB_KEY)
        else:
            self.log.info("Skipping call to %s due to only_if", proc)

    def proc_check_guard(self, sql):
        """
        Check whether the guard SQL returns a single column containing 0 or 1.
        Return True if the value is 1, otherwise False.
        """
        self.connection.open_db_connections(self.connection.PROC_GUARD_DB_KEY)
        cursor = self.connection.get_cursor(self.connection.PROC_GUARD_DB_KEY)

        should_run = False
        try:
            cursor.execute(sql, ())
            result = cursor.fetchone()
            should_run = result[0] == 1
        except Exception as e:
            self.log.error("Failed to run proc_only_if sql %s : %s", sql, e)

        self.connection.close_cursor(cursor)
        self.connection.close_db_connections(self.connection.PROC_GUARD_DB_KEY)
        return should_run
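A note on the pattern above: the check leans on a short-lived static-info cache, reading facts such as the major version back with .get() and simply skipping the dynamic queries when the value is missing or has expired. Below is a minimal sketch of that pattern with cachetools' TTLCache, using made-up stand-ins for the check's real cache, key, and loader.

from cachetools import TTLCache

# Hypothetical stand-ins for the check's static-info cache and key names.
STATIC_INFO_MAJOR_VERSION = "major_version"
static_info_cache = TTLCache(maxsize=16, ttl=600)  # facts expire after 10 minutes

def load_static_information(query_major_version):
    # Populate the cache only when the key is missing or has expired.
    if STATIC_INFO_MAJOR_VERSION not in static_info_cache:
        static_info_cache[STATIC_INFO_MAJOR_VERSION] = query_major_version()

def build_dynamic_queries(make_executor):
    # Mirror dynamic_queries: bail out instead of guessing when the fact is unavailable.
    major_version = static_info_cache.get(STATIC_INFO_MAJOR_VERSION)
    if not major_version:
        return None
    return make_executor(major_version)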
Example #8
class Shotachan:
    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.count_url = "http://booru.shotachan.net/post?tags="
        self.post_url = "http://booru.shotachan.net/post/index.json?tags="

    @commands.command(name='shota', pass_context=True, no_pm=True)
    async def shotachan_search(self, ctx, *args):
        """: !shota <tags>      | Post a random image from the Shotachan Booru."""

        # if no tags, default
        message = utils.parse_message(ctx.message.content)

        # get number of pages
        maxpage = self.shotachan_count(message)

        # Choose random page number
        pid = list(range(0, int(maxpage)))
        random.shuffle(pid)

        # get posts
        url = self.post_url + message + "&page=" + str(pid[0])
        posts = json.loads(self.session.get(url).text)
        count = len(posts)

        # get random post
        if posts:
            post = await self.getRandomPost(posts)

        else:
            msg = "No Results Found."
            await self.bot.send_message(ctx.message.channel, msg)
            return

        # if random post returned, create embed.
        if post:
            embed = self.create_embed(post)
            await self.bot.send_message(ctx.message.channel, embed=embed)

        else:
            msg = "All images have already been seen. Try again later."
            await self.bot.send_message(ctx.message.channel, msg)
            return

    @lru_cache(maxsize=None)
    def shotachan_count(self, message):

        url = self.count_url + message
        r = self.session.get(url)
        try:
            # Get total number of posts.
            strainer = SoupStrainer('div', attrs={'id': 'paginator'})
            soup = BeautifulSoup(r.content, "html.parser", parse_only=strainer)
            maxpage = soup.find_all('a')[-2].getText()

        except Exception:
            # Fall back to a single page if the paginator cannot be parsed.
            return 1

        return maxpage

    async def getRandomPost(self, posts):
        # create list of posts not recently seen.
        postlist = []

        for post in posts:
            if post['id'] not in self.cache.keys():
                postlist.append(post)

        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0]['id']
            self.cache[postid] = postid
            return postlist[0]

        else:
            return None

    def create_embed(self, post):
        postid = str(post['id'])
        post_url = str(post['file_url'])
        source_url = str(post['source'])
        orig_url = "http://booru.shotachan.net/post/show/{id}".format(
            id=postid)

        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.set_author(name="Shotachan",
                         icon_url="http://booru.shotachan.net/favicon.ico")
        embed.add_field(
            name="\u200B",
            value="*View on:* [[Shotachan]]({o})  |  [[Source]]({s})".format(
                o=orig_url, s=source_url),
            inline=False)
        return embed
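Shotachan's getRandomPost uses its TTLCache purely as a "recently seen" set: any post seen within the last 300 seconds is skipped, and expiry makes it eligible again without any cleanup code. A stripped-down sketch of that dedup pattern, assuming only cachetools and the standard library:

import random
from cachetools import TTLCache

seen = TTLCache(maxsize=500, ttl=300)  # ids silently drop out after 5 minutes

def pick_unseen(posts):
    # Return a random post whose id is outside the TTL window, or None.
    fresh = [p for p in posts if p["id"] not in seen]
    if not fresh:
        return None
    post = random.choice(fresh)
    seen[post["id"]] = post["id"]  # the value is irrelevant; only the key matters
    return post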
Example #9
class MemoCache(Cache):
    """Manages cached values for a single st.memo-ized function."""
    def __init__(
        self,
        key: str,
        persist: Optional[str],
        max_entries: float,
        ttl: float,
        display_name: str,
    ):
        self.key = key
        self.display_name = display_name
        self.persist = persist
        self._mem_cache = TTLCache(maxsize=max_entries,
                                   ttl=ttl,
                                   timer=_TTLCACHE_TIMER)
        self._mem_cache_lock = threading.Lock()

    @property
    def max_entries(self) -> float:
        return cast(float, self._mem_cache.maxsize)

    @property
    def ttl(self) -> float:
        return cast(float, self._mem_cache.ttl)

    def get_stats(self) -> List[CacheStat]:
        stats: List[CacheStat] = []
        with self._mem_cache_lock:
            for item_key, item_value in self._mem_cache.items():
                stats.append(
                    CacheStat(
                        category_name="st_memo",
                        cache_name=self.display_name,
                        byte_length=len(item_value),
                    ))
        return stats

    def read_value(self, key: str) -> Any:
        """Read a value from the cache. Raise `CacheKeyNotFoundError` if the
        value doesn't exist, and `CacheError` if the value exists but can't
        be unpickled.
        """
        try:
            pickled_value = self._read_from_mem_cache(key)

        except CacheKeyNotFoundError as e:
            if self.persist == "disk":
                pickled_value = self._read_from_disk_cache(key)
                self._write_to_mem_cache(key, pickled_value)
            else:
                raise e

        try:
            return pickle.loads(pickled_value)
        except pickle.UnpicklingError as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

    def write_value(self, key: str, value: Any) -> None:
        """Write a value to the cache. It must be pickleable."""
        try:
            pickled_value = pickle.dumps(value)
        except pickle.PicklingError as exc:
            raise CacheError(f"Failed to pickle {key}") from exc

        self._write_to_mem_cache(key, pickled_value)
        if self.persist == "disk":
            self._write_to_disk_cache(key, pickled_value)

    def clear(self) -> None:
        with self._mem_cache_lock:
            # We keep a lock for the entirety of the clear operation to avoid
            # disk cache race conditions.
            for key in self._mem_cache.keys():
                self._remove_from_disk_cache(key)

            self._mem_cache.clear()

    def _read_from_mem_cache(self, key: str) -> bytes:
        with self._mem_cache_lock:
            if key in self._mem_cache:
                entry = bytes(self._mem_cache[key])
                _LOGGER.debug("Memory cache HIT: %s", key)
                return entry

            else:
                _LOGGER.debug("Memory cache MISS: %s", key)
                raise CacheKeyNotFoundError("Key not found in mem cache")

    def _read_from_disk_cache(self, key: str) -> bytes:
        path = self._get_file_path(key)
        try:
            with streamlit_read(path, binary=True) as input:
                value = input.read()
                _LOGGER.debug("Disk cache HIT: %s", key)
                return bytes(value)
        except FileNotFoundError:
            raise CacheKeyNotFoundError("Key not found in disk cache")
        except BaseException as e:
            _LOGGER.error(e)
            raise CacheError("Unable to read from cache") from e

    def _write_to_mem_cache(self, key: str, pickled_value: bytes) -> None:
        with self._mem_cache_lock:
            self._mem_cache[key] = pickled_value

    def _write_to_disk_cache(self, key: str, pickled_value: bytes) -> None:
        path = self._get_file_path(key)
        try:
            with streamlit_write(path, binary=True) as output:
                output.write(pickled_value)
        except util.Error as e:
            _LOGGER.debug(e)
            # Clean up file so we don't leave zero byte files.
            try:
                os.remove(path)
            except (FileNotFoundError, IOError, OSError):
                pass
            raise CacheError("Unable to write to cache") from e

    def _remove_from_disk_cache(self, key: str) -> None:
        """Delete a cache file from disk. If the file does not exist on disk,
        return silently. If another exception occurs, log it. Does not throw.
        """
        path = self._get_file_path(key)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except BaseException as e:
            _LOGGER.exception("Unable to remove a file from the disk cache: %s", e)

    def _get_file_path(self, value_key: str) -> str:
        """Return the path of the disk cache file for the given value."""
        return get_streamlit_file_path(_CACHE_DIR_NAME,
                                       f"{self.key}-{value_key}.memo")
Example #10
class Gelbooru:
    def __init__(self, bot):
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.bot = bot
        self.session = requests.Session()

        self.url = "http://gelbooru.com/index.php?page=dapi&s=post&q=index&tags="
        self.post_url = "http://gelbooru.com/index.php?page=post&s=view&id="

        self.default_tags = None
        self.tags_file = "./config/gelbooru_defaulttags.yaml"
        self.loadDefaultTags()

    @commands.group(name='cum',
                    pass_context=True,
                    no_pm=True,
                    invoke_without_command=True)
    async def gelbooru_search(self, ctx, *args):
        """: !cum <tags>        | Post a random image from Gelboorur"""

        channel = str(ctx.message.channel)

        if ctx.invoked_subcommand is None:
            # Make sure the channel has a default tag list before reading it.
            self.checkDefaultTags(channel)
            message = utils.parse_message(ctx.message.content,
                                          tags=self.default_tags[channel])
            count = self.gelbooru_count(message)

            if count:
                pageid = self.getRandomPage(count)
            else:
                msg = "No Results Found."
                await self.bot.send_message(ctx.message.channel, msg)
                return

            # Search
            r = self.session.get(self.url + message + "&pid=" + str(pageid))
            soup = BeautifulSoup(r.content, "html.parser")
            posts = soup.find_all("post")
            post = await self.getRandomPost(posts, count)

            if post:
                embed = self.create_embed(post)
                await self.bot.send_message(ctx.message.channel, embed=embed)
            else:

                msg = "All images already seen, Try again later."
                await self.bot.send_message(ctx.message.channel, msg)
                return

    @gelbooru_search.group(name="default_tags",
                           pass_context=True,
                           no_pm=True,
                           invoke_without_command=True)
    async def defaultTags(self, ctx):
        if ctx.invoked_subcommand is None:
            channel = str(ctx.message.channel)
            self.checkDefaultTags(channel)
            await self.bot.send_message(ctx.message.channel,
                                        self.default_tags[channel])

    @defaultTags.command(name="add", pass_context=True, no_pm=True)
    async def blacklist_add(self, ctx):
        admins = [
            a.userid for a in self.bot.session.query(db.Admin.userid).all()
        ]
        if str(ctx.message.author.id) not in admins:
            return

        msg = ctx.message.content.split(' ')[3:]
        channel = str(ctx.message.channel)
        self.checkDefaultTags(channel)
        writefile = False

        for tag in msg:
            if tag not in self.default_tags[channel]:
                self.default_tags[channel].append(tag)
                writefile = True

        if writefile:
            with open(self.tags_file, 'w') as tagfile:
                yaml.dump(self.default_tags, tagfile, default_flow_style=False)

            await self.bot.send_message(ctx.message.channel,
                                        self.default_tags[channel])

    @defaultTags.command(name="remove", pass_context=True, no_pm=True)
    async def blacklist_remove(self, ctx):

        admins = [
            a.userid for a in self.bot.session.query(db.Admin.userid).all()
        ]
        if str(ctx.message.author.id) not in admins:
            return

        msg = ctx.message.content.split(' ')[3:]
        channel = str(ctx.message.channel)
        self.checkDefaultTags(channel)
        writefile = False

        for tag in msg:
            if tag in self.default_tags[channel]:
                self.default_tags[channel].remove(tag)
                writefile = True

        if writefile:
            with open(self.tags_file, 'w') as tagfile:
                yaml.dump(self.default_tags, tagfile, default_flow_style=False)

            await self.bot.send_message(ctx.message.channel,
                                        self.default_tags[channel])

    def loadDefaultTags(self):
        if Path(self.tags_file).is_file():
            with open(self.tags_file, 'r') as tagfile:
                self.default_tags = yaml.safe_load(tagfile)
        else:
            # Create an empty tags file and close the handle right away.
            open(self.tags_file, 'x').close()

        if self.default_tags is None:
            self.default_tags = dict()

    def checkDefaultTags(self, channel):
        if channel not in self.default_tags:
            self.default_tags[channel] = ['shota']

    async def getRandomPost(self, posts, count):
        # create list of posts not recently seen.
        postlist = []
        for post in posts:
            if post['id'] not in self.cache.keys():
                postlist.append(post)

        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0]['id']
            self.cache[postid] = postid

            return postlist[0]

    def getRandomPage(self, count):
        maxpage = int(round(count / 100))
        if maxpage < 1:
            maxpage = 1

        # return random page number from range(0, maxpage)
        pageid = random.sample(list(range(0, maxpage)), 1)[0]
        return pageid

    @lru_cache(maxsize=None)
    def gelbooru_count(self, message):
        # get total number of posts
        r = self.session.get(self.url + message)

        if r.status_code == 200:
            soup = BeautifulSoup(r.content, "html.parser")
            count = int(soup.find("posts")['count'])

            if count:
                return count

    def create_embed(self, post):
        # create discord embed from post
        postid = str(post['id'])
        post_url = "https:" + str(post['file_url'])
        source_url = str(post['source'])
        orig_url = self.post_url + postid
        source_text = """*View on:* [[Gelbooru]]({o})  |  [[Source]]({s})""".format(
            o=orig_url, s=source_url)

        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.set_author(name="Gelbooru",
                         icon_url="https://gelbooru.com/favicon.png")
        embed.add_field(name="\u200B", value=source_text, inline=False)
        return embed
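Both Shotachan and Gelbooru cache their post counts with functools.lru_cache on a bound method, which never expires and keeps self alive for the lifetime of the process. Since a stale count only needs to survive a few minutes, a TTL-bound cache is arguably the better fit; one possible alternative using cachetools.cached (a sketch with a placeholder lookup, not a drop-in change to the cogs):

from cachetools import TTLCache, cached

# Counts are cheap to keep for a few minutes and harmless to recompute afterwards.
_count_cache = TTLCache(maxsize=256, ttl=300)

@cached(_count_cache)
def booru_count(tags):
    # Placeholder lookup so the sketch runs standalone; the real cog would
    # issue the HTTP request and parse the paginator or post count here.
    return 100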
Example #11
class Responser(object):

    lines = []
    data = {}

    sessions = None

    def __init__(self, csv_path=None):
        self.sessions = \
            TTLCache(constants.SESSIONS_MAX_COUNT, ttl=constants.SESSION_TTL)

        # TODO fill this code when storage will be implemented
        # self.storage = Storage(csv_path=csv_path)

        # TODO remove this monkey patch code
        self.storage = type('Storage', (), dict(find=lambda x, y: None))
        self.storage.find_mock = lambda x, y: ['Apple', 'Applet', 'Appolo']

    def get_answer(self, question, chat_id):
        request = question.strip().lower()

        answ = None
        if chat_id in self.sessions.keys():
            session = self.sessions[chat_id]

            try:
                answ = session.get_answer(request)
            except Cancel as e:
                del self.sessions[chat_id]
                raise e

            del self.sessions[chat_id]

        if not answ:
            result = self.storage.find_mock(request, True)

            if not result:
                return None

            elif isinstance(result, list):
                try:
                    del self.sessions[chat_id]
                except KeyError:
                    pass

                self.sessions[chat_id] = PollSession(result)

                answ = self.make_poll_answer(result)
                return answ

            else:
                return result
        else:
            return answ

    def make_poll_answer(self, variants):
        lines = []
        for i in range(len(variants)):
            lines.append('{}. {}'.format(str(i+1).ljust(2), str(variants[i])))

        answ = '\n'.join(lines)
        answ = '{}\n0. {}'.format(answ, 'Отмена')  # 'Отмена' = 'Cancel'
        answ = '{}\n\n{}'.format(answ, Answers.select_one)

        return answ
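Responser treats the TTL cache as a per-chat session store: a PollSession lives under its chat_id and simply vanishes once the user has been idle longer than SESSION_TTL, so no explicit expiry logic is needed. A minimal sketch of that session pattern, with the constants and the session factory filled in as assumptions:

from cachetools import TTLCache

SESSIONS_MAX_COUNT = 1000  # assumed stand-ins for constants.SESSIONS_MAX_COUNT / SESSION_TTL
SESSION_TTL = 900          # seconds of inactivity before a session is dropped

sessions = TTLCache(maxsize=SESSIONS_MAX_COUNT, ttl=SESSION_TTL)

def get_session(chat_id, factory):
    # Return the live session for a chat, creating a fresh one if it expired.
    session = sessions.get(chat_id)
    if session is None:
        session = factory()
        sessions[chat_id] = session
    return session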
Example #12
class e621:
    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.url = "https://e621.net/post/index.xml?tags="
        self.default_tags = "+shota"

    @commands.command(name='fur', pass_context=True, no_pm=True)
    async def e621_search(self, ctx, *args):
        """: !fur <tags>        | Post random a image from e621"""

        channel = ctx.message.channel
        message = utils.parse_message(ctx.message.content,
                                      tags=self.default_tags)

        # Only respond in the furry channels.
        if "fur" not in str(channel):
            return

        r = requests.get(self.url + message)

        # If response is OK, continue.
        if r.status_code == 200:
            # parse page and get number of posts
            soup = BeautifulSoup(r.content, "xml")
            count = len(soup.find_all("post"))

            # Calculate number of pages (posts/limit), and search one at random.
            maxpage = int(round(count / 320))

            if maxpage < 1:
                maxpage = 1

            pid = list(range(0, maxpage))
            random.shuffle(pid)
            r = self.session.get(self.url + message + "&page=" + str(pid[0]))
            soup = BeautifulSoup(r.content, "lxml")
            posts = soup.find_all("post")

            if not posts:
                msg = "No Results Found."
                await self.bot.send_message(ctx.message.channel, msg)
                return

            post = await self.getRandomPost(posts)

            if post:
                embed = self.create_embed(post)
                await self.bot.send_message(ctx.message.channel, embed=embed)

            else:
                msg = "All images have already been seen. Try again later."
                await self.bot.send_message(ctx.message.channel, msg)

    async def getRandomPost(self, posts):
        # create list of posts not recently seen.
        postlist = []
        for post in posts:
            postid = post.find('id').text
            if postid not in self.cache.keys():
                postlist.append(post)

        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0].find('id').text
            self.cache[postid] = postid
            return postlist[0]

        else:
            return None

    def create_embed(self, post):
        post_url = post.find('file_url').text
        source_url = post.find("source").text
        postid = post.find("id").text
        orig_url = "https://e621.net/post/show/{id}".format(id=postid)
        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.add_field(
            name="e261",
            value="*View on:* [[e261]]({o})  |  [[Source]]({s})".format(
                o=orig_url, s=source_url),
            inline=False)
        return embed
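One small idiom that recurs in these cogs: membership is tested with post['id'] not in self.cache.keys(), but TTLCache is a MutableMapping, so "key in cache" works directly and reads more naturally. A tiny illustration:

from cachetools import TTLCache

cache = TTLCache(maxsize=500, ttl=300)
cache["abc"] = "abc"

assert "abc" in cache          # idiomatic membership test
assert "abc" in cache.keys()   # equivalent, but noisier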