class InMemoryCacheHandler:
    def __init__(self):
        dns_cache_size = int(setup_config_data_holder.dns_cache_size)
        dns_cache_ttl = int(setup_config_data_holder.dns_cache_ttl)
        self.cache = TTLCache(maxsize=dns_cache_size, ttl=dns_cache_ttl)

    def add_to_cache(self, key, value):
        self.cache[key] = value
        return self.cache

    def get_from_cache(self, key):
        return self.cache.get(key)

    def check_if_present(self, required_key):
        # Membership on a TTLCache already accounts for expiry.
        return required_key in self.cache

    def show_cache(self):
        logger.debug(f"[Process: Show Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")

    def clear_cache(self):
        logger.debug(f"[Process: Clearing Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")
        self.cache.clear()
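# A minimal usage sketch for the handler above. It assumes the same externals
# the snippet does (a `setup_config_data_holder` with string-valued
# dns_cache_size / dns_cache_ttl and a configured `logger`), plus
# `from cachetools import TTLCache`; the hostname and address are made up.
handler = InMemoryCacheHandler()
handler.add_to_cache("example.com", "93.184.216.34")
if handler.check_if_present("example.com"):
    print(handler.get_from_cache("example.com"))  # served from cache until the TTL lapses
handler.show_cache()
handler.clear_cache()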
class WebSocket:
    def __init__(self, host, port, maxWebsockets, socketsTTL, processor):
        self.sockets = TTLCache(maxsize=maxWebsockets, ttl=socketsTTL)
        self.t = None
        self.loop = None
        self.host = host
        self.port = port
        self.processor = processor

    def start(self):
        self.t = Thread(target=self.serve_forever)
        self.t.start()
        logging.info("WEBSOCKET STARTED IN " + self.host + ":" + str(self.port))

    def stop(self):
        # Thread._stop() is a private CPython detail and raises in Python 3;
        # stop the event loop from its own thread instead.
        if self.loop is not None:
            self.loop.call_soon_threadsafe(self.loop.stop)

    def serve_forever(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.loop = loop
        start_server = websockets.serve(self.handler, self.host, self.port)
        loop.run_until_complete(start_server)
        loop.run_forever()

    async def handler(self, websocket, path):
        while True:
            data = await websocket.recv()
            logging.info("WEBSOCKET RECEIVED " + data)
            jsonData = json.loads(data)
            token = jsonData["token"]
            if token in self.sockets:
                # close() is a coroutine in the websockets library; await it.
                await self.sockets[token].close()
            self.sockets[token] = websocket
            self.processor.checkPermissions(token)

    async def send(self, data, token):
        socket = self.sockets.get(token)
        if socket:
            await socket.send(data)

    def getUsers(self):
        return self.sockets.keys()

    def delToken(self, token):
        if token in self.sockets:
            # Called from sync code, so schedule the coroutine on the server loop.
            asyncio.run_coroutine_threadsafe(self.sockets[token].close(), self.loop)
            del self.sockets[token]
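# A minimal driver for the server above, assuming the imports the snippet
# relies on (threading.Thread, asyncio, websockets, json, logging,
# cachetools.TTLCache). `LoggingProcessor` is a hypothetical stand-in for the
# real `processor`; the host and port are illustrative.
class LoggingProcessor:
    def checkPermissions(self, token):
        logging.info("checking permissions for token %s", token)

ws = WebSocket("localhost", 8765, maxWebsockets=100, socketsTTL=3600,
               processor=LoggingProcessor())
ws.start()  # serves in a background thread until ws.stop() is called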
class Shiba:
    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)
        public_key = config.flickrPublic
        secret_key = config.flickrSecret
        flickr_api.set_keys(public_key, secret_key)

    @commands.command(pass_context=True, no_pm=True)
    async def shiba(self, ctx, *args):
        """: !shiba | post random Shibas"""
        shibas = flickr_api.Photo.search(per_page=500, tags='shiba dog')
        random.shuffle(shibas)
        sizes = ['Large', 'Medium', 'Small', 'Original']
        # Reshuffle until the first photo is one we haven't posted recently.
        while shibas[0]['id'] in self.cache:
            random.shuffle(shibas)
        photo_id = shibas[0]['id']
        self.cache[photo_id] = photo_id
        for size in sizes:
            try:
                r = self.session.get(shibas[0].getPhotoUrl(size))
                strainer = SoupStrainer('div', attrs={'id': 'allsizes-photo'})
                soup = BeautifulSoup(r.content, "html.parser", parse_only=strainer)
                file_url = soup.find('img')['src']
                embed = Embed()
                embed.set_author(name="Shiba",
                                 icon_url="http://i1.kym-cdn.com/entries/icons/facebook/000/013/564/aP2dv.jpg")
                embed.set_image(url=file_url)
                await self.bot.send_message(ctx.message.channel, embed=embed)
                return
            except Exception:
                # A bare except would also swallow KeyboardInterrupt/SystemExit;
                # this photo has no page at this size, so try the next size.
                continue
class OrgAPI(object):
    def __init__(self, symbol, admin_mode=False, url=DEFAULT_RSI_URL,
                 endpoint='/orgs', members_endpoint='/api/orgs/getOrgMembers',
                 cache_ttl=DEFAULT_CACHE_TTL, session=None, username='',
                 password=''):
        self.session = session
        self.symbol = symbol
        self.url = url.rstrip('/')
        self.endpoint = endpoint
        self.members_endpoint = members_endpoint
        self.admin_mode = admin_mode
        self.org_url = "{}/{}/{}".format(self.url, self.endpoint.lstrip('/'), symbol)
        self.members_api = "{}/{}".format(self.url, self.members_endpoint.lstrip('/'))
        self._ttlcache = TTLCache(maxsize=1, ttl=cache_ttl)

        if self.session is None:
            self.session = RSISession(url=url)
            if username and password:
                self.session.authenticate(username, password)

        # pull and cache the org details, which will raise 404 if not found
        self._update_details()

    def clear_cache(self):
        """ Resets the cache """
        # Deleting keys while iterating over them mutates the mapping mid-loop;
        # TTLCache provides clear() for exactly this.
        self._ttlcache.clear()

    def _cache(self, key, update_func, *args, **kwargs):
        if key not in self._ttlcache:
            self._ttlcache[key] = update_func(*args, **kwargs)
        return self._ttlcache[key]

    def _update_members(self, search):
        members = []
        params = {'symbol': self.symbol, 'search': search, 'page': 1}
        if self.admin_mode:
            params['admin_mode'] = 1

        # this just gets us going
        totalsize = 1

        while len(members) < totalsize:
            r = self.session.post(self.members_api, data=params)
            if r.status_code == 200:
                r = r.json()
                if r is None:
                    continue
                if 'data' in r and r['data'] and 'totalrows' in r['data']:
                    totalsize = int(r['data']['totalrows'])
                if r['success'] == 1:
                    apisoup = BeautifulSoup(r['data']['html'], features='lxml')
                    for member in apisoup.select('.member-item'):
                        members.append({
                            'name': member.select_one('.name').text,
                            'handle': member.select_one('.nick').text,
                            'avatar': '{}{}'.format(self.url, member.select_one('img').attrs['src']),
                            'affiliate': member.select_one('.title').text == 'Affiliate',
                            'rank': member.select_one('.rank').text,
                            'roles': [_.text for _ in member.select('.rolelist .role')],
                            'url': '{}{}'.format(self.url, member.select_one('a.membercard').attrs['href']),
                        })
                        if self.admin_mode:
                            members[-1].update({
                                'id': member.attrs.get('data-member-id', ''),
                                'last_online': member.select_one('.frontinfo .lastonline').text,
                                # was 'visibility ' with a trailing space in the key
                                'visibility': member.select_one('.frontinfo .visibility').text,
                            })
                    params['page'] = params['page'] + 1
                else:
                    raise ValueError('Received error fetching Org members: {}'.format(r))
            else:
                raise Exception('Received error fetching Org members: {}'.format(r.status_code))
            time.sleep(0.5)
        return members

    def _update_details(self):
        r = self.session.get(self.org_url)
        data = {}
        r.raise_for_status()
        orgsoup = BeautifulSoup(r.text, features='lxml')
        data['banner'] = '{}{}'.format(self.url, orgsoup.select_one('.banner img')['src'])
        data['logo'] = '{}{}'.format(self.url, orgsoup.select_one('.logo img')['src'])
        data['name'], data['symbol'] = orgsoup.select_one('.inner h1').text.split(' / ')
        data['model'] = orgsoup.select_one('.inner .tags .model').text
        data['commitment'] = orgsoup.select_one('.inner .tags .commitment').text
        data['primary_focus'] = orgsoup.select_one('.inner .focus .primary img')['alt']
        data['secondary_focus'] = orgsoup.select_one('.inner .focus .secondary img')['alt']
        data['join_us'] = orgsoup.select_one('.join-us .body').text.strip()
        return data

    def search(self, handle, score_cutoff=80, limit=None):
        """ Return members that match the given handle using fuzzy matching.

        :param handle: Handle to match
        :param score_cutoff: minimum matching score to return
        :param limit: limit the number of matches found
        :return: List of matched results in the form of [(dict, int)] where
            dict is the member data and int is the matching confidence
        """
        choices = {i: _['name'] for i, _ in enumerate(self.members)}
        return [(self.members[_[2]], _[1])
                for _ in process.extractBests(handle, choices,
                                              score_cutoff=score_cutoff,
                                              limit=limit)]

    def search_one(self, handle):
        """ Return the first member that matches the given handle using fuzzy
        matching, or None

        :param handle: Handle to match
        :return: The best matching member, or None
        """
        choices = self.search(handle, limit=1)
        if choices:
            return choices[0][0]
        return None

    @property
    def members(self):
        return self._cache('members', self._update_members, search='')

    @property
    def details(self):
        return self._cache('details', self._update_details)

    @property
    def banner(self):
        return self.details['banner']

    @property
    def logo(self):
        return self.details['logo']

    @property
    def name(self):
        return self.details['name']

    @property
    def model(self):
        return self.details['model']

    @property
    def commitment(self):
        return self.details['commitment']

    @property
    def primary_focus(self):
        return self.details['primary_focus']

    @property
    def secondary_focus(self):
        return self.details['secondary_focus']

    @property
    def spectrum_url(self):
        return '{}/spectrum/community/{}'.format(self.url, self.symbol)

    @property
    def join_us(self):
        return self.details['join_us']
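# A sketch of how OrgAPI might be driven, assuming DEFAULT_RSI_URL and
# RSISession from the surrounding module; the org symbol and member handle
# here are illustrative, not real data.
org = OrgAPI('EXAMPLE')             # fetches the org page; raises on a 404
print(org.name, org.model)          # served from the TTL-cached details dict
hit = org.search_one('some_handle')
if hit:
    print(hit['handle'], hit['rank'])
org.clear_cache()                   # next property access refetches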
class Bot:
    """The twitter bot."""

    # Consts
    CALL_TEXT = "{} second encrypted call at {}"
    HASHTAGS = "#SeattleEncryptedComms"
    TWEET_PADDING = 20
    BASE_URL = "https://api.openmhz.com/kcers1b/calls/newer?time={}&filter-type=talkgroup&filter-code=44912,45040,45112,45072,45136"
    # DEBUG URL TO GET A LOT OF API RESPONSES
    # BASE_URL = "https://api.openmhz.com/kcers1b/calls/newer?time={}"

    def __init__(self) -> None:
        """Initializes the class."""
        self.callThreshold = int(os.getenv("CALL_THRESHOLD", 1))
        self.debug = os.getenv("DEBUG", "true").lower() == "true"
        self.reportLatency = os.getenv("REPORT_LATENCY", "false").lower() == "true"
        self.window_minutes = int(os.getenv("WINDOW_M", 5))
        self.timezone = pytz.timezone(os.getenv("TIMEZONE", "US/Pacific"))
        # The actual look back is the length of this lookback + lag compensation.
        # For example: 300+45=345 seconds
        self.lookback = int(os.getenv("LOOKBACK_S", 300))  # env values are strings
        self.cachedTweet: int = None
        self.cachedTime: datetime = None
        self.cache = TTLCache(maxsize=100, ttl=self.lookback)
        self.scraper = Scraper.Instance(self.BASE_URL, self.lookback)
        self.latency = [timedelta(seconds=0)]

        if not self.debug:
            # Does not need to be saved for later.
            # If the keys aren't in env this will still run.
            auth = tweepy.OAuthHandler(os.getenv("CONSUMER_KEY", ""),
                                       os.getenv("CONSUMER_SECRET", ""))
            auth.set_access_token(os.getenv("ACCESS_TOKEN_KEY", ""),
                                  os.getenv("ACCESS_TOKEN_SECRET", ""))
            self.api = tweepy.API(auth)
            # Test the authentication. This will gracefully fail if the keys aren't present.
            try:
                self.api.rate_limit_status()
            except TweepError as e:
                if e.api_code == 215:
                    log.error("No keys or bad keys")
                else:
                    log.error("Other API error: {}".format(e))
                exit(1)

        self.interval = Set.Interval(30, self._check)

    def _kill(self) -> None:
        """This kills the c̶r̶a̶b̶ bot."""
        self.interval.cancel()
        exit(0)

    def _getUniqueCalls(self, calls: list) -> list:
        """Filters the return from the scraper to only tweet unique calls.

        Works by checking if the cache already has that call ID.

        Args:
            calls (list): The complete list of calls scraped.

        Returns:
            list: A filtered list of calls.
        """
        res: List[dict] = []
        for call in calls:
            # If the call is already in the cache, skip it.
            if call["_id"] in self.cache:
                continue
            # If it isn't, cache it and return it.
            else:
                # Might want to actually store something? Who knows.
                self.cache.update({call["_id"]: 0})
                res.append(call)
        return res

    def _check(self) -> None:
        """Checks the API and sends a tweet if needed."""
        try:
            log.info(f"Checking!: {datetime.now()}")
            try:
                json = self.scraper.getJSON()
                calls = self._getUniqueCalls(json["calls"])
                log.info(f"Found {len(calls)} calls.")
                if len(calls) > 0:
                    self._postTweet(calls)
            except TypeError as e:
                if json is None:
                    # We already have an error message from the scraper
                    return
                log.exception(e)
                return
        except KeyboardInterrupt as e:
            # Literally impossible to hit, which might be an issue? Catching a
            # keyboard interrupt could happen in its own thread or something,
            # but that sounds complicated 👉👈
            self._kill()

        if self.reportLatency:
            # `sum` must not shadow the builtin it relies on, and summing
            # timedeltas needs a timedelta start value.
            total = sum(self.latency, timedelta()).total_seconds()
            avg = round(total / len(self.latency), 3)
            log.info(f"Average latency for the last 100 calls: {avg} seconds")

    def _postTweet(self, calls: list) -> None:
        """Posts a tweet.

        Args:
            calls (list): The call objects to post about.
        """
        # Filter to make sure that calls are actually recent. There can be a
        # weird behavior of the API returning multiple-hours-old calls all at
        # once. Also filters out calls under the length threshold.
        filteredCalls: List[dict] = []
        for call in calls:
            diff = datetime.now(pytz.utc) - datetime.strptime(
                call["time"], "%Y-%m-%dT%H:%M:%S.000%z")
            if not abs(diff.total_seconds()) >= 1.8e3:
                if call["len"] < self.callThreshold:
                    log.debug(
                        f"Call of size {call['len']} below threshold ({self.callThreshold})"
                    )
                    continue
                filteredCalls.append(call)
                if self.reportLatency:
                    # Store latency
                    self.latency.append(diff)
                    if len(self.latency) > 100:
                        self.latency.pop(0)

        if len(filteredCalls) == 0:
            # If there's nothing to post, simply leave
            return

        msgs = self._generateTweets(filteredCalls)

        if self.debug:
            msg = " | ".join(msgs)
            log.debug(f"Would have posted: {msg}")
            return

        # Check for a cached tweet, then check if the last tweet was less than
        # the window ago. If the window has expired, dereference the cached tweet.
        if (self.cachedTime is not None
                and self.cachedTime + timedelta(minutes=self.window_minutes) <= datetime.now()):
            self.cachedTweet = None

        try:
            if self.cachedTweet is not None:
                for msg in msgs:
                    # Every time it posts, the new ID gets stored, so this works
                    self.cachedTweet = self.api.update_status(msg, self.cachedTweet).id
            else:
                for index, msg in enumerate(msgs):
                    if index == 0:
                        # Since there isn't a cached tweet yet, we have to send a non-reply first
                        self.cachedTweet = self.api.update_status(msg).id
                    else:
                        self.cachedTweet = self.api.update_status(msg, self.cachedTweet).id
            self.cachedTime = datetime.now()
        except tweepy.TweepError as e:
            log.exception(e)

    def _timeString(self, call: dict) -> str:
        """Generates a time code string for a call.

        Args:
            call (dict): The call to get time from.

        Returns:
            str: A timestamp string in I:M:S am/pm format.
        """
        # Get time from the call.
        date = datetime.strptime(call["time"], "%Y-%m-%dT%H:%M:%S.000%z")
        # F**k I hate how computer time works
        localized = date.replace(tzinfo=pytz.utc).astimezone(self.timezone)
        normalized = self.timezone.normalize(localized)
        return normalized.strftime("%#I:%M:%S %p")

    def _chunk(self, callStrings: list) -> list:
        """Chunks tweets into an acceptable length.

        Chunking. Shamelessly stolen from
        `SeattleDSA/signal_scanner_bot/twitter.py` :)

        Args:
            callStrings (list): List of strings derived from calls.

        Returns:
            list: A list of tweet strings to post
        """
        tweetList: List[str] = []
        baseIndex = 0
        # Instead of splitting on words, split along call lines.
        subTweet: str = ""
        for index in range(len(callStrings)):
            if len(tweetList) == 0:
                subTweet = (", ".join(callStrings[baseIndex:index]) + " ... " + self.HASHTAGS)
            elif index < len(callStrings):
                subTweet = ", ".join(callStrings[baseIndex:index]) + " ..."
            elif index == len(callStrings):
                subTweet = ", ".join(callStrings[baseIndex:index])
            if len(subTweet) > 280 - self.TWEET_PADDING:
                lastIndex = index - 1
                tweetList.append(", ".join(callStrings[baseIndex:lastIndex]) + " ...")
                baseIndex = lastIndex
        tweetList.append(", ".join(callStrings[baseIndex:]))

        listLength = len(tweetList)
        for index in range(len(tweetList)):
            if index == 0:
                tweetList[index] += f" {self.HASHTAGS} {index + 1}/{listLength}"
            else:
                tweetList[index] += f" {index + 1}/{listLength}"
        return tweetList

    def _generateTweets(self, calls: list) -> list:
        """Generates tweet messages.

        Args:
            calls (list): The calls to tweet about.

        Returns:
            list: The tweet messages, hopefully right around the character limit.
        """
        callStrings: List[str] = []
        # First, take all of the calls and turn them into strings.
        for call in calls:
            callStrings.append(
                self.CALL_TEXT.format(
                    call["len"],
                    self._timeString(call),
                ))
        tweet = ", ".join(callStrings) + " " + self.HASHTAGS
        # If we don't have to chunk, we can just leave.
        if len(tweet) <= 280:
            return [tweet]
        return self._chunk(callStrings)
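# The deduplication in _getUniqueCalls above is just "membership test, then
# insert" on a TTLCache. The same pattern in isolation (call IDs are made up):
from cachetools import TTLCache

seen = TTLCache(maxsize=100, ttl=300)

def unique(calls):
    fresh = [c for c in calls if c["_id"] not in seen]
    for c in fresh:
        seen[c["_id"]] = 0  # the value is unused; presence is what matters
    return fresh

print(unique([{"_id": "a"}, {"_id": "b"}]))  # both pass
print(unique([{"_id": "a"}, {"_id": "c"}]))  # only "c" passes within the TTL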
class WSNewBlocksWatcher(BaseWatcher):

    MESSAGE_TIMEOUT = 30.0
    PING_TIMEOUT = 10.0

    def __init__(self, w3: Web3, websocket_url):
        super().__init__(w3)
        self._network_on = False
        self._nonce: int = 0
        self._current_block_number: int = -1
        self._websocket_url = websocket_url
        self._node_address = None
        self._client: Optional[websockets.WebSocketClientProtocol] = None
        self._fetch_new_blocks_task: Optional[asyncio.Task] = None
        self._block_cache = TTLCache(maxsize=10, ttl=120)

    _nbw_logger: Optional[HummingbotLogger] = None

    @classmethod
    def logger(cls) -> HummingbotLogger:
        if cls._nbw_logger is None:
            cls._nbw_logger = logging.getLogger(__name__)
        return cls._nbw_logger

    @property
    def block_number(self) -> int:
        return self._current_block_number

    @property
    def block_cache(self) -> Dict[HexBytes, AttributeDict]:
        cache_dict: Dict[HexBytes, AttributeDict] = dict(
            (key, self._block_cache[key]) for key in self._block_cache.keys())
        return cache_dict

    async def start_network(self):
        if self._fetch_new_blocks_task is not None:
            await self.stop_network()
        else:
            try:
                self._current_block_number = await self.call_async(
                    getattr, self._w3.eth, "blockNumber")
            except asyncio.CancelledError:
                raise
            except Exception:
                self.logger().network("Error fetching newest Ethereum block number.",
                                      app_warning_msg="Error fetching newest Ethereum block number. "
                                                      "Check Ethereum node connection",
                                      exc_info=True)
            await self.connect()
            await self.subscribe(["newHeads"])
            self._fetch_new_blocks_task: asyncio.Task = safe_ensure_future(
                self.fetch_new_blocks_loop())
            self._network_on = True

    async def stop_network(self):
        if self._fetch_new_blocks_task is not None:
            await self.disconnect()
            self._fetch_new_blocks_task.cancel()
            self._fetch_new_blocks_task = None
            self._network_on = False

    async def connect(self):
        try:
            self._client = await websockets.connect(uri=self._websocket_url)
            return self._client
        except Exception as e:
            self.logger().network(f"ERROR in connection: {e}")

    async def disconnect(self):
        try:
            await self._client.close()
        except Exception as e:
            self.logger().network(f"ERROR in disconnection: {e}")

    async def _send(self, emit_data) -> int:
        self._nonce += 1
        emit_data["id"] = self._nonce
        await self._client.send(ujson.dumps(emit_data))
        return self._nonce

    async def subscribe(self, params) -> bool:
        emit_data = {
            "method": "eth_subscribe",
            "params": params
        }
        nonce = await self._send(emit_data)
        raw_message = await self._client.recv()
        if raw_message is not None:
            resp = ujson.loads(raw_message)
            if resp.get("id", None) == nonce:
                self._node_address = resp.get("result")
                return True
        return False

    async def _messages(self) -> AsyncIterable[Any]:
        try:
            while True:
                try:
                    raw_msg_str: str = await asyncio.wait_for(
                        self._client.recv(), self.MESSAGE_TIMEOUT)
                    yield raw_msg_str
                except asyncio.TimeoutError:
                    try:
                        pong_waiter = await self._client.ping()
                        await asyncio.wait_for(pong_waiter, timeout=self.PING_TIMEOUT)
                    except asyncio.TimeoutError:
                        raise
        except asyncio.TimeoutError:
            self.logger().warning("WebSocket ping timed out. Going to reconnect...")
            return
        except ConnectionClosed:
            return
        finally:
            await self.disconnect()

    async def fetch_new_blocks_loop(self):
        while True:
            try:
                async for raw_message in self._messages():
                    message_json = ujson.loads(raw_message) if raw_message is not None else None
                    if message_json.get("method", None) == "eth_subscription":
                        subscription_result_params = message_json.get("params", None)
                        incoming_block = subscription_result_params.get("result", None) \
                            if subscription_result_params is not None else None
                        if incoming_block is not None:
                            new_block: AttributeDict = await self.call_async(
                                self._w3.eth.getBlock, incoming_block.get("hash"), True)
                            self._current_block_number = new_block.get("number")
                            self._block_cache[new_block.get("hash")] = new_block
                            self.trigger_event(NewBlocksWatcherEvent.NewBlocks, [new_block])
            except asyncio.TimeoutError:
                self.logger().network("Timed out fetching new block.", exc_info=True,
                                      app_warning_msg="Timed out fetching new block. "
                                                      "Check wallet network connection")
            except asyncio.CancelledError:
                raise
            except BlockNotFound:
                pass
            except Exception as e:
                self.logger().network(f"Error fetching new block: {e}", exc_info=True,
                                      app_warning_msg="Error fetching new block. "
                                                      "Check wallet network connection")
            await asyncio.sleep(30.0)

    async def get_timestamp_for_block(self, block_hash: HexBytes, max_tries: Optional[int] = 10) -> int:
        counter = 0
        block: AttributeDict = None
        if block_hash in self._block_cache:
            block = self._block_cache.get(block_hash)
        else:
            while block is None:
                if counter == max_tries:
                    raise ValueError(f"Block hash {block_hash.hex()} does not exist.")
                counter += 1
                block = self._block_cache.get(block_hash)
                await asyncio.sleep(0.5)
        return block.get("timestamp")
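# A rough sketch of driving the watcher above. BaseWatcher, call_async and
# safe_ensure_future come from the surrounding Hummingbot code, and the node
# URLs below are placeholders, so treat this as illustrative only.
async def watch_blocks():
    w3 = Web3(Web3.HTTPProvider("http://localhost:8545"))    # placeholder node
    watcher = WSNewBlocksWatcher(w3, "ws://localhost:8546")  # placeholder ws endpoint
    await watcher.start_network()
    await asyncio.sleep(300)  # blocks accumulate in the cache (maxsize=10, ttl=120)
    print(watcher.block_number, len(watcher.block_cache))
    await watcher.stop_network()

# asyncio.run(watch_blocks())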
class SQLServer(AgentCheck):
    __NAMESPACE__ = 'sqlserver'

    def __init__(self, name, init_config, instances):
        super(SQLServer, self).__init__(name, init_config, instances)
        self._resolved_hostname = None
        self._agent_hostname = None
        self.connection = None
        self.failed_connections = {}
        self.instance_metrics = []
        self.instance_per_type_metrics = defaultdict(set)
        self.do_check = True

        self.tags = self.instance.get("tags", [])
        self.reported_hostname = self.instance.get('reported_hostname')
        self.autodiscovery = is_affirmative(self.instance.get('database_autodiscovery'))
        self.autodiscovery_include = self.instance.get('autodiscovery_include', ['.*'])
        self.autodiscovery_exclude = self.instance.get('autodiscovery_exclude', [])
        self.autodiscovery_db_service_check = is_affirmative(self.instance.get('autodiscovery_db_service_check', True))
        self.min_collection_interval = self.instance.get('min_collection_interval', 15)
        self._compile_patterns()
        self.autodiscovery_interval = self.instance.get('autodiscovery_interval', DEFAULT_AUTODISCOVERY_INTERVAL)
        self.databases = set()
        self.ad_last_check = 0

        self.proc = self.instance.get('stored_procedure')
        self.proc_type_mapping = {'gauge': self.gauge, 'rate': self.rate, 'histogram': self.histogram}
        self.custom_metrics = init_config.get('custom_metrics', [])

        # DBM
        self.dbm_enabled = self.instance.get('dbm', False)
        self.statement_metrics_config = self.instance.get('query_metrics', {}) or {}
        self.statement_metrics = SqlserverStatementMetrics(self)
        self.activity_config = self.instance.get('query_activity', {}) or {}
        self.activity = SqlserverActivity(self)
        self.cloud_metadata = {}
        aws = self.instance.get('aws', {})
        gcp = self.instance.get('gcp', {})
        azure = self.instance.get('azure', {})
        if aws:
            self.cloud_metadata.update({'aws': aws})
        if gcp:
            self.cloud_metadata.update({'gcp': gcp})
        if azure:
            self.cloud_metadata.update({'azure': azure})
        obfuscator_options_config = self.instance.get('obfuscator_options', {}) or {}
        self.obfuscator_options = to_native_string(
            json.dumps(
                {
                    # Valid values for this can be found at
                    # https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/database.md#connection-level-attributes
                    'dbms': 'mssql',
                    'replace_digits': is_affirmative(
                        obfuscator_options_config.get(
                            'replace_digits',
                            obfuscator_options_config.get('quantize_sql_tables', False),
                        )
                    ),
                    'keep_sql_alias': is_affirmative(obfuscator_options_config.get('keep_sql_alias', True)),
                    'return_json_metadata': is_affirmative(obfuscator_options_config.get('collect_metadata', True)),
                    'table_names': is_affirmative(obfuscator_options_config.get('collect_tables', True)),
                    'collect_commands': is_affirmative(obfuscator_options_config.get('collect_commands', True)),
                    'collect_comments': is_affirmative(obfuscator_options_config.get('collect_comments', True)),
                }
            )
        )

        self.static_info_cache = TTLCache(
            maxsize=100,
            # cache these for a full day
            ttl=60 * 60 * 24,
        )

        # Query declarations
        check_queries = []
        if is_affirmative(self.instance.get('include_ao_metrics', False)):
            check_queries.extend(
                [
                    QUERY_AO_AVAILABILITY_GROUPS,
                    QUERY_AO_FAILOVER_CLUSTER,
                    QUERY_AO_FAILOVER_CLUSTER_MEMBER,
                ]
            )
        if is_affirmative(self.instance.get('include_fci_metrics', False)):
            check_queries.extend([QUERY_FAILOVER_CLUSTER_INSTANCE])
        self._check_queries = self._new_query_executor(check_queries)
        self.check_initializations.append(self._check_queries.compile_queries)

        self.server_state_queries = self._new_query_executor([QUERY_SERVER_STATIC_INFO])
        self.check_initializations.append(self.server_state_queries.compile_queries)

        # use QueryManager to process custom queries
        self._query_manager = QueryManager(
            self, self.execute_query_raw, tags=self.tags, hostname=self.resolved_hostname
        )

        self._dynamic_queries = None

        self.check_initializations.append(self.config_checks)
        self.check_initializations.append(self._query_manager.compile_queries)
        self.check_initializations.append(self.initialize_connection)

    def cancel(self):
        self.statement_metrics.cancel()
        self.activity.cancel()

    def config_checks(self):
        if self.autodiscovery and self.instance.get('database'):
            self.log.warning(
                'sqlserver `database_autodiscovery` and `database` options defined in same instance - '
                'autodiscovery will take precedence.'
            )
        if not self.autodiscovery and (self.autodiscovery_include or self.autodiscovery_exclude):
            self.log.warning(
                "Autodiscovery is disabled, autodiscovery_include and autodiscovery_exclude will be ignored"
            )

    def split_sqlserver_host_port(self, host):
        """
        Splits the host & port out of the provided SQL Server host connection string, returning (host, port).
        """
        if not host:
            return host, None
        host_split = [s.strip() for s in host.split(',')]
        if len(host_split) == 1:
            return host_split[0], None
        if len(host_split) == 2:
            return host_split
        # else len > 2
        s_host, s_port = host_split[0:2]
        self.log.warning(
            "invalid sqlserver host string has more than one comma: %s. using only 1st two items: host:%s, port:%s",
            host,
            s_host,
            s_port,
        )
        return s_host, s_port

    def _new_query_executor(self, queries):
        return QueryExecutor(
            self.execute_query_raw,
            self,
            queries=queries,
            tags=self.tags,
            hostname=self.resolved_hostname,
        )

    @property
    def resolved_hostname(self):
        if self._resolved_hostname is None:
            if self.reported_hostname:
                self._resolved_hostname = self.reported_hostname
            elif self.dbm_enabled:
                host, port = self.split_sqlserver_host_port(self.instance.get('host'))
                self._resolved_hostname = resolve_db_host(host)
            else:
                self._resolved_hostname = self.agent_hostname
        return self._resolved_hostname

    def load_static_information(self):
        expected_keys = {STATIC_INFO_VERSION, STATIC_INFO_MAJOR_VERSION, STATIC_INFO_ENGINE_EDITION}
        missing_keys = expected_keys - set(self.static_info_cache.keys())
        if missing_keys:
            with self.connection.open_managed_default_connection():
                with self.connection.get_managed_cursor() as cursor:
                    if STATIC_INFO_VERSION not in self.static_info_cache:
                        cursor.execute("select @@version")
                        results = cursor.fetchall()
                        if results and len(results) > 0 and len(results[0]) > 0 and results[0][0]:
                            version = results[0][0]
                            self.static_info_cache[STATIC_INFO_VERSION] = version
                            self.static_info_cache[STATIC_INFO_MAJOR_VERSION] = parse_sqlserver_major_version(version)
                            if not self.static_info_cache[STATIC_INFO_MAJOR_VERSION]:
                                self.log.warning("failed to parse SQL Server major version from version: %s", version)
                        else:
                            self.log.warning("failed to load version static information due to empty results")
                    if STATIC_INFO_ENGINE_EDITION not in self.static_info_cache:
                        cursor.execute("SELECT CAST(ServerProperty('EngineEdition') AS INT) AS Edition")
                        result = cursor.fetchone()
                        if result:
                            # fetchone() returns a row; store the scalar edition value
                            self.static_info_cache[STATIC_INFO_ENGINE_EDITION] = result[0]
                        else:
                            self.log.warning("failed to load engine edition static information due to empty results")

    def debug_tags(self):
        return self.tags + ['agent_hostname:{}'.format(self.agent_hostname)]

    def debug_stats_kwargs(self, tags=None):
        tags = tags if tags else []
        return {
            "tags": self.debug_tags() + tags,
            "hostname": self.resolved_hostname,
            "raw": True,
        }

    @property
    def agent_hostname(self):
        # type: () -> str
        if self._agent_hostname is None:
            self._agent_hostname = datadog_agent.get_hostname()
        return self._agent_hostname

    def initialize_connection(self):
        self.connection = Connection(self.init_config, self.instance, self.handle_service_check)

        # Pre-process the list of metrics to collect
        try:
            # check to see if the database exists before we try any connections to it
            db_exists, context = self.connection.check_database()
            if db_exists:
                if self.instance.get('stored_procedure') is None:
                    with self.connection.open_managed_default_connection():
                        with self.connection.get_managed_cursor() as cursor:
                            self.autodiscover_databases(cursor)
                        self._make_metric_list_to_collect(self.custom_metrics)
            else:
                # How much do we care that the DB doesn't exist?
                ignore = is_affirmative(self.instance.get("ignore_missing_database", False))
                if ignore is not None and ignore:
                    # not much: we expect it. leave checks disabled
                    self.do_check = False
                    self.log.warning("Database %s does not exist. Disabling checks for this instance.", context)
                else:
                    # yes we do. Keep trying
                    msg = "Database {} does not exist. Please resolve invalid database and restart agent".format(
                        context
                    )
                    raise ConfigurationError(msg)
        except SQLConnectionError as e:
            self.log.exception("Error connecting to database: %s", e)
        except ConfigurationError:
            raise
        except Exception as e:
            self.log.exception("Initialization exception %s", e)

    def handle_service_check(self, status, host, database, message=None, is_default=True):
        custom_tags = self.instance.get("tags", [])
        disable_generic_tags = self.instance.get('disable_generic_tags', False)
        if custom_tags is None:
            custom_tags = []
        if disable_generic_tags:
            service_check_tags = ['sqlserver_host:{}'.format(host), 'db:{}'.format(database)]
        else:
            service_check_tags = ['host:{}'.format(host), 'sqlserver_host:{}'.format(host), 'db:{}'.format(database)]
        service_check_tags.extend(custom_tags)
        service_check_tags = list(set(service_check_tags))

        if status is AgentCheck.OK:
            message = None

        if is_default:
            self.service_check(SERVICE_CHECK_NAME, status, tags=service_check_tags, message=message, raw=True)
        if self.autodiscovery and self.autodiscovery_db_service_check:
            self.service_check(DATABASE_SERVICE_CHECK_NAME, status, tags=service_check_tags, message=message, raw=True)

    def _compile_patterns(self):
        self._include_patterns = self._compile_valid_patterns(self.autodiscovery_include)
        self._exclude_patterns = self._compile_valid_patterns(self.autodiscovery_exclude)

    def _compile_valid_patterns(self, patterns):
        valid_patterns = []
        for pattern in patterns:
            # Ignore empty patterns as they match everything
            if not pattern:
                continue
            try:
                re.compile(pattern, re.IGNORECASE)
            except Exception:
                self.log.warning('%s is not a valid regular expression and will be ignored', pattern)
            else:
                valid_patterns.append(pattern)
        if valid_patterns:
            return re.compile('|'.join(valid_patterns), re.IGNORECASE)
        else:
            # create unmatchable regex - https://stackoverflow.com/a/1845097/2157429
            return re.compile(r'(?!x)x')

    def autodiscover_databases(self, cursor):
        if not self.autodiscovery:
            return False

        now = time.time()
        if now - self.ad_last_check > self.autodiscovery_interval:
            self.log.info('Performing database autodiscovery')
            cursor.execute(AUTODISCOVERY_QUERY)
            all_dbs = set(row.name for row in cursor.fetchall())
            excluded_dbs = set([d for d in all_dbs if self._exclude_patterns.match(d)])
            included_dbs = set([d for d in all_dbs if self._include_patterns.match(d)])

            self.log.debug(
                'Autodiscovered databases: %s, excluding: %s, including: %s', all_dbs, excluded_dbs, included_dbs
            )

            # keep included dbs but remove any that were explicitly excluded
            filtered_dbs = all_dbs.intersection(included_dbs) - excluded_dbs

            self.log.debug('Resulting filtered databases: %s', filtered_dbs)
            self.ad_last_check = now

            if filtered_dbs != self.databases:
                self.log.debug('Databases updated from previous autodiscovery check.')
                self.databases = filtered_dbs
                return True
        return False

    def _make_metric_list_to_collect(self, custom_metrics):
        """
        Store the list of metrics to collect by instance_key.
        Will also create and cache cursors to query the db.
        """
        metrics_to_collect = []
        tags = self.instance.get('tags', [])

        # Load instance-level (previously Performance) metrics.
        # If several check instances are querying the same server host, it can be wise to turn these off
        # to avoid sending duplicate metrics
        if is_affirmative(self.instance.get('include_instance_metrics', True)):
            common_metrics = INSTANCE_METRICS
            if not self.dbm_enabled:
                common_metrics = common_metrics + DBM_MIGRATED_METRICS

            self._add_performance_counters(
                chain(common_metrics, INSTANCE_METRICS_TOTAL), metrics_to_collect, tags, db=None
            )

        # populated through autodiscovery
        if self.databases:
            for db in self.databases:
                self._add_performance_counters(INSTANCE_METRICS_TOTAL, metrics_to_collect, tags, db=db)

        # Load database statistics
        for name, table, column in DATABASE_METRICS:
            # include database as a filter option
            db_names = self.databases or [self.instance.get('database', self.connection.DEFAULT_DATABASE)]
            for db_name in db_names:
                cfg = {'name': name, 'table': table, 'column': column, 'instance_name': db_name, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load AlwaysOn metrics
        if is_affirmative(self.instance.get('include_ao_metrics', False)):
            for name, table, column in AO_METRICS + AO_METRICS_PRIMARY + AO_METRICS_SECONDARY:
                db_name = 'master'
                cfg = {
                    'name': name,
                    'table': table,
                    'column': column,
                    'instance_name': db_name,
                    'tags': tags,
                    'ao_database': self.instance.get('ao_database', None),
                    'availability_group': self.instance.get('availability_group', None),
                    'only_emit_local': is_affirmative(self.instance.get('only_emit_local', False)),
                }
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load metrics from scheduler and task tables, if enabled
        if is_affirmative(self.instance.get('include_task_scheduler_metrics', False)):
            for name, table, column in TASK_SCHEDULER_METRICS:
                cfg = {'name': name, 'table': table, 'column': column, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load sys.master_files metrics
        if is_affirmative(self.instance.get('include_master_files_metrics', False)):
            for name, table, column in DATABASE_MASTER_FILES:
                cfg = {'name': name, 'table': table, 'column': column, 'tags': tags}
                metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load DB Fragmentation metrics
        if is_affirmative(self.instance.get('include_db_fragmentation_metrics', False)):
            db_fragmentation_object_names = self.instance.get('db_fragmentation_object_names', [])
            db_names = self.databases or [self.instance.get('database', self.connection.DEFAULT_DATABASE)]

            if not db_fragmentation_object_names:
                self.log.debug(
                    "No fragmentation object names specified, will return fragmentation metrics for all "
                    "object_ids of current database(s): %s",
                    db_names,
                )

            for db_name in db_names:
                for name, table, column in DATABASE_FRAGMENTATION_METRICS:
                    cfg = {
                        'name': name,
                        'table': table,
                        'column': column,
                        'instance_name': db_name,
                        'tags': tags,
                        'db_fragmentation_object_names': db_fragmentation_object_names,
                    }
                    metrics_to_collect.append(self.typed_metric(cfg_inst=cfg, table=table, column=column))

        # Load any custom metrics from conf.d/sqlserver.yaml
        for cfg in custom_metrics:
            sql_type = None
            base_name = None

            custom_tags = tags + cfg.get('tags', [])
            cfg['tags'] = custom_tags

            db_table = cfg.get('table', DEFAULT_PERFORMANCE_TABLE)
            if db_table not in VALID_TABLES:
                self.log.error('%s has an invalid table name: %s', cfg['name'], db_table)
                continue

            if cfg.get('database', None) and cfg.get('database') != self.instance.get('database'):
                self.log.debug(
                    'Skipping custom metric %s for database %s, check instance configured for database %s',
                    cfg['name'],
                    cfg.get('database'),
                    self.instance.get('database'),
                )
                continue

            if db_table == DEFAULT_PERFORMANCE_TABLE:
                user_type = cfg.get('type')
                if user_type is not None and user_type not in VALID_METRIC_TYPES:
                    self.log.error('%s has an invalid metric type: %s', cfg['name'], user_type)
                sql_type = None
                try:
                    if user_type is None:
                        sql_type, base_name = self.get_sql_type(cfg['counter_name'])
                except Exception:
                    self.log.warning("Can't load the metric %s, ignoring", cfg['name'], exc_info=True)
                    continue

                metrics_to_collect.append(
                    self.typed_metric(
                        cfg_inst=cfg, table=db_table, base_name=base_name, user_type=user_type, sql_type=sql_type
                    )
                )
            else:
                for column in cfg['columns']:
                    metrics_to_collect.append(
                        self.typed_metric(
                            cfg_inst=cfg, table=db_table, base_name=base_name, sql_type=sql_type, column=column
                        )
                    )

        self.instance_metrics = metrics_to_collect
        self.log.debug("metrics to collect %s", metrics_to_collect)

        # create an organized grouping of metric names to their metric classes
        for m in metrics_to_collect:
            cls = m.__class__.__name__
            name = m.sql_name or m.column
            self.log.debug("Adding metric class %s named %s", cls, name)

            self.instance_per_type_metrics[cls].add(name)
            if m.base_name:
                self.instance_per_type_metrics[cls].add(m.base_name)

    def _add_performance_counters(self, metrics, metrics_to_collect, tags, db=None):
        if db is not None:
            tags = tags + ['database:{}'.format(db)]
        for name, counter_name, instance_name in metrics:
            try:
                sql_type, base_name = self.get_sql_type(counter_name)
                cfg = {
                    'name': name,
                    'counter_name': counter_name,
                    'instance_name': db or instance_name,
                    'tags': tags,
                }

                metrics_to_collect.append(
                    self.typed_metric(
                        cfg_inst=cfg, table=DEFAULT_PERFORMANCE_TABLE, base_name=base_name, sql_type=sql_type
                    )
                )
            except SQLConnectionError:
                raise
            except Exception:
                self.log.warning("Can't load the metric %s, ignoring", name, exc_info=True)
                continue

    def get_sql_type(self, counter_name):
        """
        Return the type of the performance counter so that we can report it to
        Datadog correctly
        If the sql_type is one that needs a base (PERF_RAW_LARGE_FRACTION and
        PERF_AVERAGE_BULK), the name of the base counter will also be returned
        """
        with self.connection.get_managed_cursor() as cursor:
            cursor.execute(COUNTER_TYPE_QUERY, (counter_name,))
            (sql_type,) = cursor.fetchone()
            if sql_type == PERF_LARGE_RAW_BASE:
                self.log.warning("Metric %s is of type Base and shouldn't be reported this way", counter_name)
            base_name = None
            if sql_type in [PERF_AVERAGE_BULK, PERF_RAW_LARGE_FRACTION]:
                # This is an ugly hack. For certain types of metric (PERF_RAW_LARGE_FRACTION
                # and PERF_AVERAGE_BULK), we need two metrics: the metric specified and
                # a base metric to get the ratio. There is no unique schema, so we generate
                # the possible candidates and look at which ones exist in the db.
                candidates = (
                    counter_name + " base",
                    counter_name.replace("(ms)", "base"),
                    counter_name.replace("Avg ", "") + " base",
                )
                try:
                    cursor.execute(BASE_NAME_QUERY, candidates)
                    base_name = cursor.fetchone().counter_name.strip()
                    self.log.debug("Got base metric: %s for metric: %s", base_name, counter_name)
                except Exception as e:
                    self.log.warning("Could not get counter_name of base for metric: %s", e)

        return sql_type, base_name

    def typed_metric(self, cfg_inst, table, base_name=None, user_type=None, sql_type=None, column=None):
        """
        Create the appropriate BaseSqlServerMetric object, each implementing its
        method to fetch the metrics properly.
        If a `type` was specified in the config, it is used to report the value
        directly fetched from SQLServer. Otherwise, it is decided based on the
        sql_type, according to microsoft's documentation.
        """
        if table == DEFAULT_PERFORMANCE_TABLE:
            metric_type_mapping = {
                PERF_COUNTER_BULK_COUNT: (self.rate, metrics.SqlSimpleMetric),
                PERF_COUNTER_LARGE_RAWCOUNT: (self.gauge, metrics.SqlSimpleMetric),
                PERF_LARGE_RAW_BASE: (self.gauge, metrics.SqlSimpleMetric),
                PERF_RAW_LARGE_FRACTION: (self.gauge, metrics.SqlFractionMetric),
                PERF_AVERAGE_BULK: (self.gauge, metrics.SqlIncrFractionMetric),
            }
            if user_type is not None:
                # user type overrides any other value
                metric_type = getattr(self, user_type)
                cls = metrics.SqlSimpleMetric
            else:
                metric_type, cls = metric_type_mapping[sql_type]
        else:
            # Lookup metrics classes by their associated table
            metric_type_str, cls = metrics.TABLE_MAPPING[table]
            metric_type = getattr(self, metric_type_str)
        cfg_inst['hostname'] = self.resolved_hostname

        return cls(cfg_inst, base_name, metric_type, column, self.log)

    def check(self, _):
        if self.do_check:
            self.load_static_information()
            if self.proc:
                self.do_stored_procedure_check()
            else:
                self.collect_metrics()
            if self.autodiscovery and self.autodiscovery_db_service_check:
                for db_name in self.databases:
                    if db_name != self.connection.DEFAULT_DATABASE:
                        try:
                            self.connection.check_database_conns(db_name)
                        except Exception as e:
                            # service_check errors on auto discovered databases should not abort the check
                            self.log.warning("failed service check for auto discovered database: %s", e)

            if self.dbm_enabled:
                self.statement_metrics.run_job_loop(self.tags)
                self.activity.run_job_loop(self.tags)
        else:
            self.log.debug("Skipping check")

    @property
    def dynamic_queries(self):
        """
        Initializes dynamic queries which depend on static information loaded from the database
        """
        if self._dynamic_queries:
            return self._dynamic_queries

        major_version = self.static_info_cache.get(STATIC_INFO_MAJOR_VERSION)
        if not major_version:
            self.log.warning("missing major_version, cannot initialize dynamic queries")
            return None

        queries = [get_query_file_stats(major_version)]

        self._dynamic_queries = self._new_query_executor(queries)
        self._dynamic_queries.compile_queries()
        self.log.debug("initialized dynamic queries")
        return self._dynamic_queries

    def collect_metrics(self):
        """Fetch the metrics from all of the associated database tables."""

        with self.connection.open_managed_default_connection():
            with self.connection.get_managed_cursor() as cursor:
                # initiate autodiscovery, or rebuild the list if the server was
                # down at check __init__ and the metrics are missing.
                if self.autodiscover_databases(cursor) or not self.instance_metrics:
                    self._make_metric_list_to_collect(self.custom_metrics)

                instance_results = {}

                # Execute the `fetch_all` operations first to minimize the database calls
                for cls, metric_names in six.iteritems(self.instance_per_type_metrics):
                    if not metric_names:
                        instance_results[cls] = None, None
                    else:
                        try:
                            db_names = self.databases or [
                                self.instance.get('database', self.connection.DEFAULT_DATABASE)
                            ]
                            rows, cols = getattr(metrics, cls).fetch_all_values(
                                cursor, list(metric_names), self.log, databases=db_names
                            )
                        except Exception as e:
                            self.log.error("Error running `fetch_all` for metrics %s - skipping. Error: %s", cls, e)
                            rows, cols = None, None

                        instance_results[cls] = rows, cols

                # Using the cached data, extract and report individual metrics
                for metric in self.instance_metrics:
                    if type(metric) is metrics.SqlIncrFractionMetric:
                        # special case, since it uses the same results as SqlFractionMetric
                        key = 'SqlFractionMetric'
                    else:
                        key = metric.__class__.__name__

                    if key not in instance_results:
                        self.log.warning("No %s metrics found, skipping", str(key))
                    else:
                        rows, cols = instance_results[key]
                        if rows is not None:
                            metric.fetch_metric(rows, cols)

            # Neither pyodbc nor adodbapi are able to read results of a query if the number of rows affected
            # statement are returned as part of the result set, so we disable for the entire connection
            # this is important mostly for custom_queries or the stored_procedure feature
            # https://docs.microsoft.com/en-us/sql/t-sql/statements/set-nocount-transact-sql
            with self.connection.get_managed_cursor() as cursor:
                cursor.execute("SET NOCOUNT ON")
            try:
                # Server state queries require VIEW SERVER STATE permissions, which some managed database
                # versions do not support.
                if self.static_info_cache.get(STATIC_INFO_ENGINE_EDITION) not in [
                    ENGINE_EDITION_SQL_DATABASE,
                ]:
                    self.server_state_queries.execute()

                self._check_queries.execute()
                if self.dynamic_queries:
                    self.dynamic_queries.execute()
                # reuse connection for any custom queries
                self._query_manager.execute()
            finally:
                with self.connection.get_managed_cursor() as cursor:
                    cursor.execute("SET NOCOUNT OFF")

    def execute_query_raw(self, query):
        with self.connection.get_managed_cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()

    def do_stored_procedure_check(self):
        """
        Fetch the metrics from the stored proc
        """

        proc = self.proc
        guardSql = self.instance.get('proc_only_if')
        custom_tags = self.instance.get("tags", [])

        if (guardSql and self.proc_check_guard(guardSql)) or not guardSql:
            self.connection.open_db_connections(self.connection.DEFAULT_DB_KEY)
            cursor = self.connection.get_cursor(self.connection.DEFAULT_DB_KEY)

            try:
                self.log.debug("Calling Stored Procedure : %s", proc)
                if self.connection.get_connector() == 'adodbapi':
                    cursor.callproc(proc)
                else:
                    # pyodbc does not support callproc; use execute instead.
                    # Reference: https://github.com/mkleehammer/pyodbc/wiki/Calling-Stored-Procedures
                    call_proc = '{{CALL {}}}'.format(proc)
                    cursor.execute(call_proc)

                rows = cursor.fetchall()
                self.log.debug("Row count (%s) : %s", proc, cursor.rowcount)

                for row in rows:
                    tags = [] if row.tags is None or row.tags == '' else row.tags.split(',')
                    tags.extend(custom_tags)

                    if row.type.lower() in self.proc_type_mapping:
                        self.proc_type_mapping[row.type](row.metric, row.value, tags, raw=True)
                    else:
                        self.log.warning(
                            '%s is not a recognised type from procedure %s, metric %s', row.type, proc, row.metric
                        )

            except Exception as e:
                self.log.warning("Could not call procedure %s: %s", proc, e)
                raise e

            self.connection.close_cursor(cursor)
            self.connection.close_db_connections(self.connection.DEFAULT_DB_KEY)
        else:
            self.log.info("Skipping call to %s due to only_if", proc)

    def proc_check_guard(self, sql):
        """
        check to see if the guard SQL returns a single column containing 0 or 1
        We return true if 1, else False
        """
        self.connection.open_db_connections(self.connection.PROC_GUARD_DB_KEY)
        cursor = self.connection.get_cursor(self.connection.PROC_GUARD_DB_KEY)

        should_run = False
        try:
            cursor.execute(sql, ())
            result = cursor.fetchone()
            should_run = result[0] == 1
        except Exception as e:
            self.log.error("Failed to run proc_only_if sql %s : %s", sql, e)

        self.connection.close_cursor(cursor)
        self.connection.close_db_connections(self.connection.PROC_GUARD_DB_KEY)
        return should_run
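# The static-information pattern above (query the server only for keys missing
# from a day-long TTLCache) reduced to a standalone sketch. The fetch function
# and key names are illustrative, not the check's real API.
from cachetools import TTLCache

static_info = TTLCache(maxsize=100, ttl=60 * 60 * 24)  # re-resolve once a day

def fetch_version():
    return "15.0.2000.5"  # stand-in for `select @@version`

def load_static_info():
    if {"version", "major_version"} - set(static_info.keys()):
        version = fetch_version()  # only hit the server for missing keys
        static_info["version"] = version
        static_info["major_version"] = int(version.split(".")[0])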
class Shotachan:
    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.count_url = "http://booru.shotachan.net/post?tags="
        self.post_url = "http://booru.shotachan.net/post/index.json?tags="

    @commands.command(name='shota', pass_context=True, no_pm=True)
    async def shotachan_search(self, ctx, *args):
        """: !shota <tags> | Post a random image from the Shotachan Booru."""
        # if no tags, default
        message = utils.parse_message(ctx.message.content)
        # get number of pages
        maxpage = self.shotachan_count(message)
        # Choose random page number
        pid = list(range(0, int(maxpage)))
        random.shuffle(pid)
        # get posts
        url = self.post_url + message + "&page=" + str(pid[0])
        posts = json.loads(self.session.get(url).text)
        count = len(posts)
        # get random post
        if posts:
            post = await self.getRandomPost(posts)
        else:
            msg = "No Results Found."
            await self.bot.send_message(ctx.message.channel, msg)
            return
        # if random post returned, create embed.
        if post:
            embed = self.create_embed(post)
            await self.bot.send_message(ctx.message.channel, embed=embed)
        else:
            msg = "All images have already been seen. Try again later."
            await self.bot.send_message(ctx.message.channel, msg)
            return

    @lru_cache(maxsize=None)
    def shotachan_count(self, message):
        url = self.count_url + message
        r = self.session.get(url)
        try:
            # Get total number of posts.
            strainer = SoupStrainer('div', attrs={'id': 'paginator'})
            soup = BeautifulSoup(r.content, "html.parser", parse_only=strainer)
            maxpage = soup.find_all('a')[-2].getText()
        except Exception:
            # No paginator on the page means there is only one page.
            return 1
        return maxpage

    async def getRandomPost(self, posts):
        # create list of posts not recently seen.
        postlist = []
        for post in posts:
            if post['id'] not in self.cache:
                postlist.append(post)
        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0]['id']
            self.cache[postid] = postid
            return postlist[0]
        else:
            return None

    def create_embed(self, post):
        postid = str(post['id'])
        post_url = str(post['file_url'])
        source_url = str(post['source'])
        orig_url = "http://booru.shotachan.net/post/show/{id}".format(id=postid)
        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.set_author(name="Shotachan",
                         icon_url="http://booru.shotachan.net/favicon.ico")
        embed.add_field(
            name="\u200B",
            value="*View on:* [[Shotachan]]({o}) | [[Source]]({s})".format(
                o=orig_url, s=source_url),
            inline=False)
        return embed
class MemoCache(Cache):
    """Manages cached values for a single st.memo-ized function."""

    def __init__(
        self,
        key: str,
        persist: Optional[str],
        max_entries: float,
        ttl: float,
        display_name: str,
    ):
        self.key = key
        self.display_name = display_name
        self.persist = persist
        self._mem_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
        self._mem_cache_lock = threading.Lock()

    @property
    def max_entries(self) -> float:
        return cast(float, self._mem_cache.maxsize)

    @property
    def ttl(self) -> float:
        return cast(float, self._mem_cache.ttl)

    def get_stats(self) -> List[CacheStat]:
        stats: List[CacheStat] = []
        with self._mem_cache_lock:
            for item_key, item_value in self._mem_cache.items():
                stats.append(
                    CacheStat(
                        category_name="st_memo",
                        cache_name=self.display_name,
                        byte_length=len(item_value),
                    ))
        return stats

    def read_value(self, key: str) -> Any:
        """Read a value from the cache.

        Raise `CacheKeyNotFoundError` if the value doesn't exist, and
        `CacheError` if the value exists but can't be unpickled.
        """
        try:
            pickled_value = self._read_from_mem_cache(key)
        except CacheKeyNotFoundError as e:
            if self.persist == "disk":
                pickled_value = self._read_from_disk_cache(key)
                self._write_to_mem_cache(key, pickled_value)
            else:
                raise e

        try:
            return pickle.loads(pickled_value)
        except pickle.UnpicklingError as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

    def write_value(self, key: str, value: Any) -> None:
        """Write a value to the cache. It must be pickleable."""
        try:
            pickled_value = pickle.dumps(value)
        except pickle.PicklingError as exc:
            raise CacheError(f"Failed to pickle {key}") from exc

        self._write_to_mem_cache(key, pickled_value)
        if self.persist == "disk":
            self._write_to_disk_cache(key, pickled_value)

    def clear(self) -> None:
        with self._mem_cache_lock:
            # We keep a lock for the entirety of the clear operation to avoid
            # disk cache race conditions.
            for key in self._mem_cache.keys():
                self._remove_from_disk_cache(key)
            self._mem_cache.clear()

    def _read_from_mem_cache(self, key: str) -> bytes:
        with self._mem_cache_lock:
            if key in self._mem_cache:
                entry = bytes(self._mem_cache[key])
                _LOGGER.debug("Memory cache HIT: %s", key)
                return entry
            else:
                _LOGGER.debug("Memory cache MISS: %s", key)
                raise CacheKeyNotFoundError("Key not found in mem cache")

    def _read_from_disk_cache(self, key: str) -> bytes:
        path = self._get_file_path(key)
        try:
            with streamlit_read(path, binary=True) as input:
                value = input.read()
                _LOGGER.debug("Disk cache HIT: %s", key)
                return bytes(value)
        except FileNotFoundError:
            raise CacheKeyNotFoundError("Key not found in disk cache")
        except BaseException as e:
            _LOGGER.error(e)
            raise CacheError("Unable to read from cache") from e

    def _write_to_mem_cache(self, key: str, pickled_value: bytes) -> None:
        with self._mem_cache_lock:
            self._mem_cache[key] = pickled_value

    def _write_to_disk_cache(self, key: str, pickled_value: bytes) -> None:
        path = self._get_file_path(key)
        try:
            with streamlit_write(path, binary=True) as output:
                output.write(pickled_value)
        except util.Error as e:
            _LOGGER.debug(e)
            # Clean up file so we don't leave zero byte files.
            try:
                os.remove(path)
            except (FileNotFoundError, IOError, OSError):
                pass
            raise CacheError("Unable to write to cache") from e

    def _remove_from_disk_cache(self, key: str) -> None:
        """Delete a cache file from disk. If the file does not exist on disk,
        return silently. If another exception occurs, log it. Does not throw.
        """
        path = self._get_file_path(key)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except BaseException as e:
            # logging.exception formats args into the message; pass the error
            # as a format argument rather than a stray extra positional.
            _LOGGER.exception("Unable to remove a file from the disk cache: %s", e)

    def _get_file_path(self, value_key: str) -> str:
        """Return the path of the disk cache file for the given value."""
        return get_streamlit_file_path(_CACHE_DIR_NAME, f"{self.key}-{value_key}.memo")
class Gelbooru:
    def __init__(self, bot):
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.bot = bot
        self.session = requests.Session()
        self.url = "http://gelbooru.com/index.php?page=dapi&s=post&q=index&tags="
        self.post_url = "http://gelbooru.com/index.php?page=post&s=view&id="
        self.default_tags = None
        self.tags_file = "./config/gelbooru_defaulttags.yaml"
        self.loadDefaultTags()

    @commands.group(name='cum', pass_context=True, no_pm=True, invoke_without_command=True)
    async def gelbooru_search(self, ctx, *args):
        """: !cum <tags> | Post a random image from Gelbooru"""
        channel = str(ctx.message.channel)
        if ctx.invoked_subcommand is None:
            message = utils.parse_message(ctx.message.content, tags=self.default_tags[channel])
            count = self.gelbooru_count(message)
            if count:
                pageid = self.getRandomPage(count)
            else:
                msg = "No Results Found."
                await self.bot.send_message(ctx.message.channel, msg)
                return
            # Search
            r = self.session.get(self.url + message + "&pid=" + str(pageid))
            soup = BeautifulSoup(r.content, "html.parser")
            posts = soup.find_all("post")
            post = await self.getRandomPost(posts, count)
            if post:
                embed = self.create_embed(post)
                await self.bot.send_message(ctx.message.channel, embed=embed)
            else:
                msg = "All images already seen, Try again later."
                await self.bot.send_message(ctx.message.channel, msg)
                return

    @gelbooru_search.group(name="default_tags", pass_context=True, no_pm=True, invoke_without_command=True)
    async def defaultTags(self, ctx):
        if ctx.invoked_subcommand is None:
            channel = str(ctx.message.channel)
            self.checkDefaultTags(channel)
            await self.bot.send_message(ctx.message.channel, self.default_tags[channel])

    @defaultTags.command(name="add", pass_context=True, no_pm=True)
    async def blacklist_add(self, ctx):
        admins = [
            a.userid for a in self.bot.session.query(db.Admin.userid).all()
        ]
        if str(ctx.message.author.id) not in admins:
            return
        msg = ctx.message.content.split(' ')[3:]
        channel = str(ctx.message.channel)
        self.checkDefaultTags(channel)
        writefile = False
        for tag in msg:
            if tag not in self.default_tags[channel]:
                self.default_tags[channel].append(tag)
                writefile = True
        if writefile:
            with open(self.tags_file, 'w') as tagfile:
                yaml.dump(self.default_tags, tagfile, default_flow_style=False)
        await self.bot.send_message(ctx.message.channel, self.default_tags[channel])

    @defaultTags.command(name="remove", pass_context=True, no_pm=True)
    async def blacklist_remove(self, ctx):
        admins = [
            a.userid for a in self.bot.session.query(db.Admin.userid).all()
        ]
        if str(ctx.message.author.id) not in admins:
            return
        msg = ctx.message.content.split(' ')[3:]
        channel = str(ctx.message.channel)
        self.checkDefaultTags(channel)
        writefile = False
        for tag in msg:
            if tag in self.default_tags[channel]:
                self.default_tags[channel].remove(tag)
                writefile = True
        if writefile:
            # was open(self.default_tags, ...), which passed the dict
            # instead of the file path
            with open(self.tags_file, 'w') as tagfile:
                yaml.dump(self.default_tags, tagfile, default_flow_style=False)
        await self.bot.send_message(ctx.message.channel, self.default_tags[channel])

    def loadDefaultTags(self):
        if Path(self.tags_file).is_file():
            with open(self.tags_file, 'r') as tagfile:
                # safe_load avoids the unsafe-Loader warning of bare yaml.load
                self.default_tags = yaml.safe_load(tagfile)
        else:
            open(self.tags_file, 'x').close()
        if self.default_tags is None:
            self.default_tags = dict()

    def checkDefaultTags(self, channel):
        if channel not in self.default_tags:
            self.default_tags[channel] = ['shota']

    async def getRandomPost(self, posts, count):
        # create list of posts not recently seen.
        postlist = []
        for post in posts:
            if post['id'] not in self.cache:
                postlist.append(post)
        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0]['id']
            self.cache[postid] = postid
            return postlist[0]
        return None

    def getRandomPage(self, count):
        maxpage = int(round(count / 100))
        if maxpage < 1:
            maxpage = 1
        # return random page number from range(0, maxpage)
        pageid = random.sample(list(range(0, maxpage)), 1)[0]
        return pageid

    @lru_cache(maxsize=None)
    def gelbooru_count(self, message):
        # get total number of posts
        r = self.session.get(self.url + message)
        if r.status_code == 200:
            soup = BeautifulSoup(r.content, "html.parser")
            count = int(soup.find("posts")['count'])
            if count:
                return count

    def create_embed(self, post):
        # create discord embed from post
        postid = str(post['id'])
        post_url = "https:" + str(post['file_url'])
        source_url = str(post['source'])
        orig_url = self.post_url + postid
        source_text = "*View on:* [[Gelbooru]]({o}) | [[Source]]({s})".format(
            o=orig_url, s=source_url)
        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.set_author(name="Gelbooru", icon_url="https://gelbooru.com/favicon.png")
        embed.add_field(name="\u200B", value=source_text, inline=False)
        return embed
class Responser(object):
    lines = []
    data = {}
    sessions = None

    def __init__(self, csv_path=None):
        self.sessions = \
            TTLCache(constants.SESSIONS_MAX_COUNT, ttl=constants.SESSION_TTL)
        # TODO fill this code when storage will be implemented
        # self.storage = Storage(csv_path=csv_path)
        # TODO remove this monkey patch code
        self.storage = type('Storage', (), dict(find=lambda x, y: None))
        self.storage.find_mock = lambda x, y: ['Apple', 'Applet', 'Appolo']

    def get_answer(self, question, chat_id):
        request = question.strip().lower()
        answ = None
        if chat_id in self.sessions:
            session = self.sessions[chat_id]
            try:
                answ = session.get_answer(request)
            except Cancel as e:
                del self.sessions[chat_id]
                raise e
            # Sessions are one-shot: drop it once it has answered.
            del self.sessions[chat_id]
        if not answ:
            result = self.storage.find_mock(request, True)
            if not result:
                return None
            elif isinstance(result, list):
                try:
                    del self.sessions[chat_id]
                except KeyError:
                    pass
                self.sessions[chat_id] = PollSession(result)
                answ = self.make_poll_answer(result)
                return answ
            else:
                return result
        else:
            return answ

    def make_poll_answer(self, variants):
        lines = []
        for i in range(len(variants)):
            lines.append('{}. {}'.format(str(i + 1).ljust(2), str(variants[i])))
        answ = '\n'.join(lines)
        answ = '{}\n0. {}'.format(answ, 'Отмена')  # 'Отмена' is Russian for 'Cancel'
        answ = '{}\n\n{}'.format(answ, Answers.select_one)
        return answ
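# The per-chat session handling above boils down to a TTLCache keyed by chat
# id, so abandoned sessions expire on their own. A standalone sketch (the
# sizes and factory are illustrative):
from cachetools import TTLCache

sessions = TTLCache(maxsize=1000, ttl=600)  # idle sessions vanish after 10 min

def get_session(chat_id, factory):
    if chat_id not in sessions:
        sessions[chat_id] = factory()
    return sessions[chat_id]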
class e621:
    def __init__(self, bot):
        self.bot = bot
        self.session = requests.Session()
        self.cache = TTLCache(maxsize=500, ttl=300)
        self.url = "https://e621.net/post/index.xml?tags="
        self.default_tags = "+shota"

    @commands.command(name='fur', pass_context=True, no_pm=True)
    async def e621_search(self, ctx, *args):
        """: !fur <tags> | Post a random image from e621"""
        channel = ctx.message.channel
        message = utils.parse_message(ctx.message.content, tags=self.default_tags)
        # Only respond in the furry channels.
        if "fur" not in str(channel):
            return
        r = self.session.get(self.url + message)
        # If response is OK, continue.
        if r.status_code == 200:
            # parse page and get number of posts
            soup = BeautifulSoup(r.content, "xml")
            count = len(soup.find_all("post"))
            # Calculate number of pages (posts/limit), and search one at random.
            maxpage = int(round(count / 320))
            if maxpage < 1:
                maxpage = 1
            pid = list(range(0, maxpage))
            random.shuffle(pid)
            r = self.session.get(self.url + message + "&page=" + str(pid[0]))
            soup = BeautifulSoup(r.content, "lxml")
            posts = soup.find_all("post")
            if len(posts) == 0:  # `is 0` relied on int interning; compare with ==
                msg = "No Results Found."
                await self.bot.send_message(ctx.message.channel, msg)
                return
            post = await self.getRandomPost(posts)
            if post:
                embed = self.create_embed(post)
                await self.bot.send_message(ctx.message.channel, embed=embed)
            else:
                msg = "All images have already been seen. Try again later."
                await self.bot.send_message(ctx.message.channel, msg)

    async def getRandomPost(self, posts):
        # create list of posts not recently seen.
        postlist = []
        for post in posts:
            postid = post.find('id').text
            if postid not in self.cache:
                postlist.append(post)
        # if post list has posts, shuffle and select one
        if len(postlist):
            random.shuffle(postlist)
            postid = postlist[0].find('id').text
            self.cache[postid] = postid
            return postlist[0]
        else:
            return None

    def create_embed(self, post):
        post_url = post.find('file_url').text
        source_url = post.find("source").text
        postid = post.find("id").text
        orig_url = "https://e621.net/post/show/{id}".format(id=postid)
        embed = discord.Embed(title="\n", url=post_url, colour=0x006FFA)
        embed.set_image(url=post_url)
        embed.add_field(
            name="e621",  # was "e261" (typo)
            value="*View on:* [[e621]]({o}) | [[Source]]({s})".format(
                o=orig_url, s=source_url),
            inline=False)
        return embed