class NetCDFData(Data):
    """Data source backed by a netCDF dataset opened through ``Dataset``.

    Used as a context manager: ``__enter__`` opens ``self.url`` read-only and
    ``__exit__`` closes it again.
    """

    def __init__(self, url):
        self._dataset = None
        # Holds the decoded timestamps under the single key "timestamps".
        self.__timestamp_cache = TTLCache(1, 3600)
        super(NetCDFData, self).__init__(url)

    def __enter__(self):
        self._dataset = Dataset(self.url, 'r')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._dataset.close()

    @property
    def timestamps(self):
        """Return the dataset's time axis as a read-only array of datetimes.

        The first of the variables ``time`` / ``time_counter`` found in the
        dataset is decoded with its own ``units`` attribute and every value
        is tagged with the UTC timezone.  The result is cached.

        Raises:
            AttributeError: if the dataset has neither variable (``var``
                stays ``None`` and ``var.units`` fails).
        """
        cached = self.__timestamp_cache.get("timestamps")
        if cached is not None:
            return cached

        var = None
        for v in ['time', 'time_counter']:
            if v in self._dataset.variables:
                var = self._dataset.variables[v]
                break

        t = netcdftime.utime(var.units)
        # A list (not a lazy ``map`` object) so that np.array builds a real
        # 1-d array on Python 3 as well as Python 2.
        timestamps = np.array(
            [t.num2date(ts).replace(tzinfo=pytz.UTC) for ts in var[:]]
        )
        timestamps.flags.writeable = False
        self.__timestamp_cache["timestamps"] = timestamps
        # Return the local value: re-reading the cache could yield None if
        # the entry expired in between.
        return timestamps
class IPBan:
    """Counts submissions per IP address and reports whether an IP is
    still within its allowed limit."""

    def __init__(self, limit=10, maxsize=16777216, ttl=None,
                 add_day_to_key=False):
        """Create the counter store.

        :param limit: highest counter value for which ``is_ok`` is still true
        :param maxsize: maximum number of keys kept in the cache
        :param ttl: when given, counters live in a ``TTLCache`` with this
            ttl; otherwise an ``LRUCache`` is used
        :param add_day_to_key: when true, today's date is part of each key
        """
        self.__submissions_per_ip = (
            LRUCache(maxsize=maxsize)
            if ttl is None
            else TTLCache(maxsize=maxsize, ttl=ttl)
        )
        self.__add_day_to_key = add_day_to_key
        self.__limit = limit

    def __gen_key(self, ip):
        """Return the cache key for ``ip`` (optionally paired with today)."""
        return (ip, date.today()) if self.__add_day_to_key else ip

    def __get_ip(self, base_ip=None):
        """Return ``base_ip`` or, when it is ``None``, the request's address."""
        return request.remote_addr if base_ip is None else base_ip

    def is_ok(self, ip=None):
        """Return True while the counter for ``ip`` does not exceed the limit."""
        cache_key = self.__gen_key(self.__get_ip(ip))
        return self.__submissions_per_ip.get(cache_key, 0) <= self.__limit

    def incr(self, ip=None):
        """Increase the counter for ``ip`` by one."""
        cache_key = self.__gen_key(self.__get_ip(ip))
        counters = self.__submissions_per_ip
        counters[cache_key] = counters.get(cache_key, 0) + 1
class Converter(CurrencyRates):
    """Currency converter that keeps fetched rate tables in a TTL cache."""

    def __init__(self, ttl=20.0, maxsize=256):
        super().__init__()
        self.cache = TTLCache(ttl=ttl, maxsize=maxsize)

    def get_cached_rates(self, input_currency):
        """
        Method for caching exchange rates from forex library

        It returns all possible exchange rates for the given currency.
        The table also maps ``input_currency`` itself to ``1.0``.  When the
        rates are not available an empty dict is returned (and cached).
        """
        # Single lookup: a second ``get`` could return None if the entry
        # expired between the check and the read.
        cached = self.cache.get(input_currency)
        if cached is not None:
            return cached

        try:
            ret = self.get_rates(input_currency)
            ret[input_currency] = 1.0
        except RatesNotAvailableError:
            ret = {}
        self.cache[input_currency] = ret
        return ret

    def convert(self, amount, input_currency, output_currency):
        """
        Method for converting amount of input currency to output currency.

        If the output currency is None then it will output the amount in all
        possible currencies.

        Returns a dict mapping currency code to converted amount.  Raises
        ``KeyError`` when ``output_currency`` is not in the rate table.
        """
        rates = self.get_cached_rates(input_currency)
        if output_currency is not None:
            rates = {output_currency: rates[output_currency]}
        return {currency: rate * amount for currency, rate in rates.items()}
class NetCDFData(Data):
    """netCDF-backed data source, usable as a context manager.

    Entering the context opens ``self.url`` read-only via ``Dataset``;
    leaving it closes the dataset.
    """

    def __init__(self, url):
        self._dataset = None
        # One-slot cache for the decoded time axis (key "timestamps").
        self.__timestamp_cache = TTLCache(1, 3600)
        super(NetCDFData, self).__init__(url)

    def __enter__(self):
        self._dataset = Dataset(self.url, 'r')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._dataset.close()

    @property
    def timestamps(self):
        """Read-only numpy array of UTC-tagged datetimes for the time axis.

        Looks for a variable named ``time`` and then ``time_counter``,
        converts its values with ``netcdftime.utime(var.units)`` and caches
        the resulting array.

        Raises:
            AttributeError: when neither variable exists, because ``var``
                is still ``None`` when ``var.units`` is read.
        """
        hit = self.__timestamp_cache.get("timestamps")
        if hit is not None:
            return hit

        var = None
        for v in ['time', 'time_counter']:
            if v in self._dataset.variables:
                var = self._dataset.variables[v]
                break

        t = netcdftime.utime(var.units)
        # Materialise the values in a list: on Python 3 ``map`` is lazy and
        # np.array(map(...)) would produce a 0-d object array.
        timestamps = np.array(
            [t.num2date(ts).replace(tzinfo=pytz.UTC) for ts in var[:]]
        )
        timestamps.flags.writeable = False
        self.__timestamp_cache["timestamps"] = timestamps
        # Hand back the freshly built array rather than re-reading the
        # cache, which may already have expired the entry.
        return timestamps
class NetCDFData(Data):
    """netCDF-backed data source with interpolation settings.

    Works as a context manager: the dataset at ``self.url`` is opened
    read-only on entry and closed on exit.
    """

    def __init__(self, url):
        self._dataset = None
        # One-slot cache for the decoded time axis (key "timestamps").
        self.__timestamp_cache = TTLCache(1, 3600)
        # Interpolation settings; units of ``radius`` are not visible here.
        self.interp = "gaussian"
        self.radius = 25000
        self.neighbours = 10
        super(NetCDFData, self).__init__(url)

    def __enter__(self):
        self._dataset = Dataset(self.url, 'r')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._dataset.close()

    def get_dataset_variable(self, key):
        """Returns the value of a given variable name from the dataset."""
        return self._dataset.variables[key]

    def get_filepath(self):
        """Returns the file system path which was used to open the dataset."""
        return self._dataset.filepath()

    def is_open(self):
        """Is the dataset open or closed?"""
        return self._dataset.isopen()

    @property
    def timestamps(self):
        """Read-only numpy array of UTC-tagged datetimes for the time axis.

        The first of ``time`` / ``time_counter`` present in the dataset is
        decoded using its ``units`` attribute; the array is cached.

        Raises:
            AttributeError: when neither variable exists (``var`` remains
                ``None`` and ``var.units`` fails).
        """
        hit = self.__timestamp_cache.get("timestamps")
        if hit is not None:
            return hit

        var = None
        for v in ['time', 'time_counter']:
            if v in self._dataset.variables:
                var = self._dataset.variables[v]
                break

        t = netcdftime.utime(var.units)
        # Build from a list so the result is a proper 1-d array on
        # Python 3, where ``map`` returns a lazy iterator.
        timestamps = np.array(
            [t.num2date(ts).replace(tzinfo=pytz.UTC) for ts in var[:]]
        )
        timestamps.flags.writeable = False
        self.__timestamp_cache["timestamps"] = timestamps
        # Return the local array; a second cache read could come back None
        # if the entry expired meanwhile.
        return timestamps
class RCache:
    """Small TTL cache in front of ``requests.get``, keyed by URL."""

    def __init__(self):
        from cachetools import TTLCache
        self.c = TTLCache(1024, 60)

    async def aget(self, url, headers=None):
        """Return the cached response for ``url`` or fetch and cache it.

        The HTTP request is executed in the event loop's default executor so
        that the blocking ``requests.get`` call does not stall other
        coroutines.

        NOTE(review): a cached value that is falsy is treated as a miss and
        re-fetched; presumably this covers ``requests`` responses with an
        error status -- confirm that this is intended.
        """
        import asyncio
        import functools

        self.c.expire()
        if not (res := self.c.get(url)):
            loop = asyncio.get_running_loop()
            res = await loop.run_in_executor(
                None, functools.partial(requests.get, url, headers=headers)
            )
            self.c[url] = res
        return res
class Middleware(object):
    """ Falcon rate limiting middleware """

    def __init__(self):
        """Read the limit and window from ``goldman.config`` and build the
        counter cache from them."""
        self.count = goldman.config.RATE_LIMIT_COUNT
        self.duration = goldman.config.RATE_LIMIT_DURATION
        self.cache = TTLCache(maxsize=self.count, ttl=self.duration)

    @property
    def _error_headers(self):
        """ Headers attached to every rate-limit rejection """

        headers = {}
        headers['Retry-After'] = self.duration
        headers['X-RateLimit-Limit'] = self.count
        headers['X-RateLimit-Remaining'] = 0
        return headers

    # pylint: disable=unused-argument
    def process_request(self, req, resp):
        """ Count the request before routing; abort once the limit is hit.

        The counter key is the concatenation of the request's
        ``REMOTE_PORT`` and ``REMOTE_ADDR`` environment values.
        """

        environ = req.env
        client_key = environ['REMOTE_PORT'] + environ['REMOTE_ADDR']
        seen = self.cache.get(client_key, 0)

        if seen != self.count:
            self.cache[client_key] = seen + 1
            return

        abort(exceptions.TooManyRequests(headers=self._error_headers))
class VerifyHashCache:
    """Remembers users whose password was already verified.

    Lets a repeated check of the exact same token/password couple be
    bypassed.  Most useful for small apps running on few processes, since
    the cache is only shared between threads."""

    def __init__(self):
        """Build the TTL cache from the VERIFY_HASH_CACHE_* config values."""
        ttl = config_value("VERIFY_HASH_CACHE_TTL", default=(60 * 5))
        max_size = config_value("VERIFY_HASH_CACHE_MAX_SIZE", default=500)

        # cachetools availability should have been checked at app init, so
        # an ImportError here is simply allowed to propagate.
        from cachetools import TTLCache

        self._cache = TTLCache(max_size, ttl)

    def has_verify_hash_cache(self, user):
        """Return the cached entry for ``user.id`` (``None`` when absent)."""
        user_id = user.id
        return self._cache.get(user_id)

    def set_cache(self, user):
        """Record that the password of ``user`` has been checked."""
        user_id = user.id
        self._cache[user_id] = True

    def clear(self):
        """Drop every cached entry."""
        self._cache.clear()
class MemCache:
    """Decorator that caches the result of a coroutine function.

    ``indices`` selects which positional arguments form the cache key.
    """

    def __init__(self, capacity, *indices):
        self._cache = TTLCache(capacity, TTL, timer=time)
        self._indices = indices

    def __call__(self, f):
        """Wrap coroutine function ``f`` with the cache.

        A result of ``None`` is never stored (it is indistinguishable from a
        miss), and a result whose ``disableCache`` attribute is truthy is
        returned without being stored.
        """
        import functools

        @functools.wraps(f)
        async def g(*args, **kwargs):
            # Lists are not hashable, so they are folded into one string.
            key = hashkey(*[
                ';'.join(args[idx]) if isinstance(args[idx], list)
                else args[idx]
                for idx in self._indices
            ])
            res = self._cache.get(key)
            if res is None:
                res = await f(*args, **kwargs)
                # A missing attribute means "cacheable"; unlike the former
                # bare ``except:`` this no longer hides unrelated errors.
                if not getattr(res, 'disableCache', False):
                    self._cache[key] = res
            return res

        return g
class InMemoryCacheHandler:
    """Thin wrapper around a ``TTLCache`` sized from the setup config."""

    def __init__(self):
        """Create the cache from ``dns_cache_size`` / ``dns_cache_ttl``."""
        size = int(setup_config_data_holder.dns_cache_size)
        lifetime = int(setup_config_data_holder.dns_cache_ttl)
        self.cache = TTLCache(maxsize=size, ttl=lifetime)

    def add_to_cache(self, key, value):
        """Store ``value`` under ``key`` and return the cache object."""
        store = self.cache
        store[key] = value
        return store

    def get_from_cache(self, key):
        """Return the value stored under ``key`` or ``None``."""
        return self.cache.get(key)

    def check_if_present(self, required_key):
        """Return True when ``required_key`` is currently in the cache."""
        return required_key in self.cache.keys()

    def show_cache(self):
        """Log the cache's current size and configuration at debug level."""
        logger.debug(f"[Process: Show Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")

    def clear_cache(self):
        """Log the cache's state at debug level, then empty it."""
        logger.debug(f"[Process: Clearing Cache], "
                     f"CurrentCacheSize: [{self.cache.currsize}], "
                     f"[MaxCacheSize: {self.cache.maxsize}], "
                     f"[CacheTTL: {self.cache.ttl}], "
                     f"[CacheTimer: {self.cache.timer}]")
        self.cache.clear()
class LocalCache(BaseCache):
    """In-process cache backend built on ``TTLCache``."""

    def __init__(self, backend, cache_conf):
        """Create the cache from ``cache_conf``.

        :param backend: backend name, only used in the error raised below
        :param cache_conf: mapping with optional ``max_size`` (default 128)
            and ``ttl`` (default 86400) entries
        :raises ERROR_CACHE_CONFIGURATION: if the cache cannot be built
        """
        try:
            max_size = cache_conf.get('max_size', 128)
            ttl = cache_conf.get('ttl', 86400)
            self.cache = TTLCache(maxsize=max_size, ttl=ttl)
        except Exception:
            raise ERROR_CACHE_CONFIGURATION(backend=backend)

    def get(self, key):
        """Return the value stored under ``key`` or ``None``."""
        return self.cache.get(key)

    def set(self, key, value, expire=None):
        """Store ``value`` under ``key`` and return True.

        :raises ERROR_CACHE_OPTION: if a truthy ``expire`` is passed; the
            ttl is fixed for the whole cache.
        """
        if expire:
            raise ERROR_CACHE_OPTION(method='cache.set', option='expire')

        self.cache[key] = value
        return True

    def delete(self, *keys):
        """Remove the given keys; keys that are absent (or have already
        expired) are ignored instead of raising ``KeyError``."""
        for key in keys:
            self.cache.pop(key, None)

    def flush(self, is_async=False):
        """Remove every entry from the cache.

        Previously only a single item was popped, and flushing an empty
        cache raised ``KeyError``.  ``is_async`` is accepted but not used.
        """
        self.cache.clear()
class QiniuAuth(Auth):
    """``Auth`` subclass that caches generated tokens for ``TTL`` seconds."""

    def __init__(self, access_key, secret_key, max_token_level=None):
        """
        :param access_key: access_key
        :param secret_key: secret_key
        :param max_token_level: maximum number of tokens kept per cache;
            falls back to ``CACHE_MAX_SIZE`` when falsy
        """
        super().__init__(access_key, secret_key)
        if not max_token_level:
            max_token_level = CACHE_MAX_SIZE
        self._upload_token_cache = TTLCache(max_token_level, TTL)
        self._rtc_room_token_cache = TTLCache(max_token_level, TTL)

    def get_token(self, bucket, key=None, policy=None, strict_policy=True):
        """Return a (possibly cached) upload token for the given arguments.

        New tokens are requested with ``expires=TTL + 100`` so that they
        outlive their cache entry.
        """
        token_key = f'{bucket}_{key}_{policy}_{strict_policy}'
        token = self._upload_token_cache.get(token_key, None)
        if not token:
            token = self.upload_token(bucket, key, expires=TTL + 100,
                                      policy=policy,
                                      strict_policy=strict_policy)
            self._upload_token_cache[token_key] = token
        return token

    def get_rtc_room_token(self, room_access):
        """Return a (possibly cached) RTC room token for ``room_access``.

        When ``room_access`` is a dict that already has a ``deadline`` entry,
        that entry is refreshed in place to ``now + TTL + 100``.  The cache
        key is built WITHOUT the deadline: it changes every second, so
        including it in the key made the cache miss on practically every
        call.
        """
        if isinstance(room_access, dict) and 'deadline' in room_access:
            stable_fields = {
                name: value for name, value in room_access.items()
                if name != 'deadline'
            }
            cache_key = json.dumps(stable_fields)
            room_access['deadline'] = int(time.time()) + TTL + 100
        else:
            cache_key = json.dumps(room_access)

        token = self._rtc_room_token_cache.get(cache_key, None)
        if not token:
            # The signed payload still carries the refreshed deadline.
            token = self.token_with_data(json.dumps(room_access))
            self._rtc_room_token_cache[cache_key] = token
        return token
class ChatBot:
    """Client for the Baidu UNIT chat service with per-event sessions."""

    def __init__(self):
        """Load credentials from ``chat_config`` and create the session
        cache (128 entries, one hour ttl)."""
        self.api_key = chat_config.baidu_unit_api_key
        self.secret_key = chat_config.baidu_unit_secret_key
        self.bot_id = chat_config.baidu_unit_bot_id
        self.base_url = "https://aip.baidubce.com"
        self.token = ""
        self.sessions = TTLCache(maxsize=128, ttl=60 * 60 * 1)

    async def refresh_token(self):
        """Fetch an access token and store it in ``self.token``.

        Any failure is logged as a warning; the token is left unchanged.
        """
        try:
            url = f"{self.base_url}/oauth/2.0/token?grant_type=client_credentials&client_id={self.api_key}&client_secret={self.secret_key}"
            async with httpx.AsyncClient() as client:
                resp = await client.post(url)
                payload = resp.json()
                self.token = payload.get("access_token", "")
        except Exception as e:
            logger.warning(f"Error in refresh_token: {e}")

    async def get_reply(self, text: str, event_id: str, user_id: str) -> str:
        """Send ``text`` to the bot and return what it says.

        The session id returned by the service is remembered per
        ``event_id``.  Returns an empty string when no token could be
        obtained or when the request fails.
        """
        if not self.token:
            await self.refresh_token()
            if not self.token:
                return ""

        url = f"{self.base_url}/rpc/2.0/unit/service/chat?access_token={self.token}"
        # An unknown (or empty) session is sent as the empty string.
        session_id = self.sessions.get(event_id) or ""
        params = {
            "log_id": str(uuid.uuid4()),
            "version": "2.0",
            "service_id": self.bot_id,
            "session_id": session_id,
            "request": {
                "query": text,
                "user_id": user_id
            },
            "dialog_state": {
                "contexts": {
                    "SYS_REMEMBERED_SKILLS": ["1098087", "1098084", "1098089"]
                }
            },
        }

        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    url, data=json.dumps(params, ensure_ascii=False))
                payload = resp.json()

                answer = payload["result"]
                session_id = answer["session_id"]
                self.sessions[event_id] = session_id
                return answer["response_list"][0]["action_list"][0]["say"]
        except Exception as e:
            logger.warning(
                f"Error in get_reply({text}, {event_id}, {user_id}): {e}")
            return ""
class AuthenticationMiddleware(object):
    """Resolver middleware that attaches ``user`` and ``permission`` to the
    request context, caching database lookups per user id."""

    def __init__(self, whitelist=None, maxsize=1000, ttl=60 * 60):
        """
        :param whitelist: path prefixes that do not require authentication
            (defaults to an empty list; ``None`` avoids sharing one mutable
            default list between instances)
        :param maxsize: maximum number of cached users
        :param ttl: lifetime of a cached entry
        """
        self.whitelist = whitelist if whitelist is not None else []
        self._cache = TTLCache(maxsize=maxsize, ttl=ttl)

    def is_whitelisted(self, path):
        """Return True when ``path`` starts with any whitelisted prefix.

        Slicing is used instead of ``str.startswith`` so that the check does
        not depend on ``path`` being a string.
        """
        for unauthenticated_path in self.whitelist:
            if path[:len(unauthenticated_path)] == unauthenticated_path:
                logger.debug(f"Path is whitelisted: {path}")
                return True
        return False

    async def _get_user_and_permission(self, db, user_id):
        """Return ``(user, permission)`` for ``user_id``.

        A cached pair is only used when both parts are truthy; otherwise
        both are loaded from ``db`` and the cache entry is rewritten.
        """
        user, permission = self._cache.get(user_id, (None, None))
        if user and permission:
            logger.debug('Using cached credentials')
        else:
            logger.debug(f"Getting credentials for '{user_id}'")
            user = await User.qs(db).get(user_id)
            permission = await Permission.qs(db).find_one(
                user={'$eq': user}) if user else None
            self._cache[user_id] = (user, permission)
        return user, permission

    async def _authenticate(self, context):
        """Decode the bearer token in ``context`` and load its user.

        Raises ``Exception`` for a non-bearer scheme or a token without a
        ``sub`` claim.
        """
        scheme, token = context.request.headers['authorization'].split(' ')
        if scheme.lower() != 'bearer':
            raise Exception('invalid token')
        # NOTE(review): no ``algorithms`` argument is passed to jwt.decode;
        # confirm this matches the installed JWT library's requirements.
        payload = jwt.decode(token, key=context.config.authentication.secret)
        if not payload['sub']:
            raise Exception('token contains no "sub"')
        user_id = payload['sub']
        return await self._get_user_and_permission(context.db, user_id)

    async def resolve(self, next, root, info, *args, **kwargs):
        """Authenticate if needed, then always continue with ``next``.

        Authentication failures are deliberately not fatal: they are logged
        and the next resolver runs without ``user`` in the context.  Only
        ``Exception`` is swallowed, so task cancellation and interpreter
        shutdown are no longer suppressed.
        """
        logger.debug(f"Authenticating {info.path}")
        try:
            if self.is_whitelisted(info.path):
                logger.debug(f"Path is whitelisted: {info.path}")
            elif 'user' in info.context:
                logger.debug(f"Path is already authenticated: {info.path}")
            else:
                logger.debug(f"Path requires authentication {info.path}")
                user, permission = await self._authenticate(
                    edict(info.context))
                info.context['user'] = user
                info.context['permission'] = permission
        except Exception:
            logger.debug(f'Failed to authenticate {info.path}')

        return await next(root, info, *args, **kwargs)
def test_atomic(self):
    """Each read-style operation must see a value written just before it.

    The cache has room for one item, a ttl of 1 and an auto-advancing
    ``Timer``; the value is re-written before every read.
    """
    cache = TTLCache(maxsize=1, ttl=1, timer=Timer(auto=True))

    readers = (
        lambda: cache[1],
        lambda: cache.get(1),
        lambda: cache.pop(1),
        lambda: cache.setdefault(1),
    )
    for read in readers:
        cache[1] = 1
        self.assertEqual(1, read())
class RSSCache:
    """Key/value cache whose entries expire after a fixed time."""

    def __init__(self, cache_max_size, cache_ttl_seconds):
        """Create a ``TTLCache`` with the given capacity and ttl."""
        self.cache = TTLCache(maxsize=cache_max_size, ttl=cache_ttl_seconds)

    def update_cache(self, key, value):
        """Store ``value`` under ``key``."""
        self.cache[key] = value

    def get_from_cache(self, key):
        """Return the value stored under ``key`` or ``None``."""
        store = self.cache
        return store.get(key)
class QueryCache():
    """Unbounded TTL cache whose truthy entries are re-stored on read."""

    def __init__(self, ttl=60 * 60):
        """Create the cache with no size limit and the given ttl."""
        unlimited = float('inf')
        self.cache = TTLCache(unlimited, ttl)

    def get(self, key):
        """Return the cached value for ``key`` (``None`` when absent).

        A truthy hit is written back under the same key before being
        returned.
        """
        hit = self.cache.get(key)
        if not hit:
            return hit
        self.cache[key] = hit
        return hit

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self.cache[key] = value
class StackCache:
    """Stack cache.

    Uses process memory as a high-speed cache, which can effectively
    improve the capacity for handling concurrency.
    """

    def __init__(self, maxsize=0xff, ttl=60):
        """Create the underlying ``TTLCache``."""
        self._cache = TTLCache(maxsize, ttl)

    def has(self, key):
        """Return True when ``key`` is in the cache."""
        present = key in self._cache
        return present

    def get(self, key, default=None):
        """Return the value for ``key`` or ``default``."""
        return self._cache.get(key, default)

    def set(self, key, val):
        """Store ``val`` under ``key``."""
        self._cache[key] = val

    def incr(self, key, val=1):
        """Add ``val`` to the counter at ``key`` (starting from 0) and
        return the new value."""
        current = self.get(key, 0)
        updated = current + val
        self.set(key, updated)
        return updated

    def decr(self, key, val=1):
        """Subtract ``val`` from the counter at ``key`` (starting from 0)
        and return the new value."""
        current = self.get(key, 0)
        updated = current - val
        self.set(key, updated)
        return updated

    def delete(self, key):
        """Remove ``key``; raises ``KeyError`` when it is not present."""
        del self._cache[key]

    def size(self):
        """Return the number of cached items."""
        return len(self._cache)

    def clear(self):
        """Remove all items."""
        return self._cache.clear()
class LinkIService(IService):
    """Picks a link service able to reach a target domain, caching the
    services that answered for a short time."""

    service_class = LinkService

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # domain -> deque of services that answered the last query round
        self.__services_by_domain = TTLCache(maxsize=100, ttl=5)

    async def get_service_for(self, target_domain, ignore_services):
        """Return a service whose ``query(target_domain)`` succeeded.

        Cached services for the domain are handed out round-robin.
        Otherwise every service not in ``ignore_services`` is queried
        concurrently; as soon as at least one query finishes without error
        the remaining queries are cancelled and the successful services are
        cached.  Returns ``None`` when no service succeeded.

        NOTE(review): a cache hit does not consult ``ignore_services`` --
        confirm this is intended.
        """
        services = self.__services_by_domain.get(target_domain)
        if services:
            services.rotate(-1)
            return services[0]

        services = []
        queries = {
            service: asyncio.ensure_future(
                service.query(target_domain),
                loop=self.loop,
            )
            for service in self.services
            if service not in ignore_services
        }

        def query_done(service, task):
            # Runs as a done-callback: drop the finished query and remember
            # the service if the query neither failed nor was cancelled.
            queries.pop(service, None)
            if not task.cancelled() and not task.exception():
                services.append(service)

        for service, task in queries.items():
            task.add_done_callback(partial(query_done, service))

        while queries and not services:
            # No ``loop=`` argument: it was removed from asyncio.wait in
            # Python 3.10; the tasks already belong to the running loop.
            await asyncio.wait(
                list(queries.values()),
                return_when=asyncio.FIRST_COMPLETED,
            )

        for task in list(queries.values()):
            task.cancel()

        if services:
            self.__services_by_domain[target_domain] = deque(services)
            return services[0]
        return None
class BaseCryptoExec:
    """Base for crypto helpers: runs callables in an executor and keeps a
    TTL cache of keys."""

    def __init__(self, loop=None, executor=None, cacheSize=128, cacheTTL=120):
        """
        :param loop: event loop to use; defaults to ``asyncio.get_event_loop()``
        :param executor: executor to use; defaults to a 4-worker thread pool
        :param cacheSize: maximum number of cached keys
        :param cacheTTL: lifetime of a cached key
        """
        self.loop = loop or asyncio.get_event_loop()
        self.executor = executor or concurrent.futures.ThreadPoolExecutor(
            max_workers=4)
        self._keysCache = TTLCache(maxsize=cacheSize, ttl=cacheTTL)

    async def _exec(self, fn, *args, **kw):
        """Run ``fn(*args, **kw)`` in the executor and return its result."""
        bound_call = functools.partial(fn, *args, **kw)
        result = await self.loop.run_in_executor(self.executor, bound_call)
        return result

    def cachedKey(self, keyId):
        """Return the cached data for ``keyId`` or ``None``."""
        return self._keysCache.get(keyId)

    def cacheKey(self, keyId, data):
        """Cache ``data`` under ``keyId``."""
        self._keysCache[keyId] = data
class WebSocket:
    """Websocket server running in its own thread, tracking one socket per
    client token in a TTL cache."""

    def __init__(self, host, port, maxWebsockets, socketsTTL, processor):
        """
        :param host: interface to listen on
        :param port: port to listen on
        :param maxWebsockets: maximum number of tracked sockets
        :param socketsTTL: lifetime of a tracked socket entry
        :param processor: object whose ``checkPermissions(token)`` is called
            for every received message
        """
        self.sockets=TTLCache(maxsize=maxWebsockets, ttl=socketsTTL)
        self.t=None
        # Event loop of the server thread; set by serve_forever().
        self.loop=None
        self.host=host
        self.port=port
        self.processor=processor

    def start(self):
        """Start the server in a background thread."""
        self.t=Thread(target=self.serve_forever)
        self.t.start()
        logging.info("WEBSOCKET STARTED IN "+self.host+":"+str(self.port))

    def stop(self):
        """Ask the server's event loop to stop.

        The loop is stopped from its own thread via
        ``call_soon_threadsafe``, which lets ``serve_forever`` (and with it
        the thread) return.  The previous ``Thread._stop()`` call is a
        private API that does not terminate a running thread.  Does nothing
        if the server loop has not been created or is already closed.
        """
        loop = self.loop
        if loop is not None and not loop.is_closed():
            loop.call_soon_threadsafe(loop.stop)

    def serve_forever(self):
        """Create an event loop for this thread and serve until stopped."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.loop = loop
        start_server = websockets.serve(self.handler, self.host, self.port)
        loop.run_until_complete(start_server)
        loop.run_forever()

    async def handler(self, websocket, path):
        """Register the socket under the token of every message it sends.

        Each message must be JSON with a ``token`` entry.  A different
        socket previously registered for that token is closed first.
        """
        while True:
            data = await websocket.recv()
            logging.info("WEBSOCKET RECEIVED "+data)
            jsonData=json.loads(data)
            token=jsonData["token"]
            previous = self.sockets.get(token)
            # close() is awaited so the old connection really is closed;
            # the identity check keeps a socket from closing itself when it
            # sends a second message.
            if previous is not None and previous is not websocket:
                await previous.close()
            self.sockets[token]=websocket
            self.processor.checkPermissions(token)

    async def send(self, data, token):
        """Send ``data`` to the socket registered for ``token``, if any."""
        socket = self.sockets.get(token)
        if socket:
            await socket.send(data)

    def getUsers(self):
        """Return the tokens that currently have a registered socket."""
        return self.sockets.keys()

    def delToken(self, token):
        """Forget ``token`` and close its socket.

        This method is synchronous, so the close coroutine is scheduled on
        the server loop instead of being created and never awaited.
        """
        socket = self.sockets.pop(token, None)
        loop = self.loop
        if socket is not None and loop is not None and not loop.is_closed():
            asyncio.run_coroutine_threadsafe(socket.close(), loop)
class InternalCache:
    """Simple Caching structure

    Values are stored per ``(url, file_path)`` pair and expire after
    ``time_to_live`` seconds.
    """

    def __init__(self, maxsize=float("inf"), time_to_live=24 * 60 * 60) -> None:
        self._cache = TTLCache(maxsize=maxsize, ttl=time_to_live)

    def update(self, url: str, file_path: str, value: str) -> None:
        """Store ``value`` for the given ``url`` / ``file_path`` pair."""
        # The tuple itself is the key.  Using ``hash(...)`` of the tuple
        # allowed two different pairs with colliding hashes to overwrite
        # and return each other's value.
        self._cache[(url, file_path)] = value

    def get(self, url: str, file_path: str) -> Union[str, None]:
        """Return the value stored for the pair, or ``None`` if absent."""
        return self._cache.get((url, file_path), None)

    def clear(self) -> None:
        """Remove every cached value."""
        self._cache.clear()
class CustomPoolManager(PoolManager):
    """``PoolManager`` whose pools can connect to a custom IP per host.

    ``resolving_cache`` maps host names to the IP that should be used; the
    custom pool classes receive it as the ``custom_ip`` keyword.
    """

    def __init__(self, *args, **kwargs):
        super(CustomPoolManager, self).__init__(*args, **kwargs)

        # custom: begins
        self.resolving_cache = TTLCache(maxsize=10000, ttl=(60 * 60))
        self.pool_classes_by_scheme = {
            'http': CustomHttpConnectionPool,
            'https': CustomHttpsConnectionPool,
        }
        # custom: ends

    # custom: begins
    def get_custom_ip(self, host):
        """Return the cached custom IP for ``host`` or ``None``."""
        return self.resolving_cache.get(host, None)
    # custom: ends

    def _new_pool(self, scheme, host, port, request_context=None):
        """Create a connection pool for ``scheme``/``host``/``port``.

        ``request_context`` defaults to a copy of ``connection_pool_kw``.
        When a custom IP is cached for ``host`` it is passed to the pool as
        ``custom_ip``.
        """
        pool_cls = self.pool_classes_by_scheme[scheme]
        if request_context is None:
            request_context = self.connection_pool_kw.copy()

        # Although the context has everything necessary to create the pool,
        # this function has historically only used the scheme, host, and port
        # in the positional args. When an API change is acceptable these can
        # be removed.
        for key in ('scheme', 'host', 'port'):
            request_context.pop(key, None)

        if scheme == 'http':
            for kw in SSL_KEYWORDS:
                request_context.pop(kw, None)

        # custom: begins
        custom_ip = self.get_custom_ip(host)
        if custom_ip:
            # Reuse the value just read: indexing the TTL cache a second
            # time could raise KeyError if the entry expired in between.
            request_context['custom_ip'] = custom_ip
        # custom: ends

        return pool_cls(host, port, **request_context)
class IPBlock(object):
    """Blocks requests coming from blacklisted IP addresses."""

    def __init__(self, app, read_preference=None, cache_size=None,
                 cache_ttl=None, blocking_enabled=True, logging_enabled=False):
        """
        Initialize IPBlock and set up a before_request handler in the app.

        You can override the default MongoDB read preference via the
        optional read_preference kwarg.

        You can limit the impact of the IP checks on your MongoDB by
        maintaining a local in-memory TTL cache. To do so, specify both its
        cache_size (i.e. max number of IP addresses it can store) and
        cache_ttl (i.e. how many seconds each result should be cached for).

        To run in dry-run mode without blocking requests, set
        blocking_enabled to False. Set logging_enabled to True to log IPs
        that match blocking rules -- if enabled, will log even if
        blocking_enabled is False.
        """
        self.read_preference = read_preference
        self.blocking_enabled = blocking_enabled
        self.logger = None
        if logging_enabled:
            self.logger = app.logger
        self.block_msg = "blocking" if blocking_enabled else "blocking disabled"

        if cache_size and cache_ttl:
            # inline import because cachetools dependency is optional.
            from cachetools import TTLCache
            self.cache = TTLCache(cache_size, cache_ttl)
        else:
            self.cache = None

        app.before_request(self.block_before)

    def block_before(self):
        """
        Check the current request and block it if the IP address it's
        coming from is blacklisted.

        Returns ``('IP Blocked', 200)`` when the request is blocked and
        ``None`` otherwise.
        """
        # To avoid unnecessary database queries, ignore the IP check for
        # requests for static files
        if request.path.startswith(url_for('static', filename='')):
            return

        # Some static files might be served from the root path (e.g.
        # favicon.ico, robots.txt, etc.). Ignore the IP check for most
        # common extensions of those files.
        ignored_extensions = ('ico', 'png', 'txt', 'xml')
        if request.path.rsplit('.', 1)[-1] in ignored_extensions:
            return

        ips = request.headers.getlist('X-Forwarded-For')
        if not ips:
            return

        # If the X-Forwarded-For header contains multiple comma-separated
        # IP addresses, we're only interested in the last one.
        ip = ips[0].strip()
        # endswith() instead of ip[-1] so an empty header value cannot
        # raise IndexError.
        if ip.endswith(','):
            ip = ip[:-1]
        ip = ip.rsplit(',', 1)[-1].strip()
        if not ip:
            # Nothing usable in the header: skip the lookup entirely.
            return

        if self.matches_ip(ip):
            if self.logger is not None:
                self.logger.info("IPBlock: matched {}, {}".format(ip, self.block_msg))
            if self.blocking_enabled:
                return 'IP Blocked', 200

    def matches_ip(self, ip):
        """Return True if the given IP is blacklisted, False otherwise."""

        # Check the cache if caching is enabled
        if self.cache is not None:
            matches_ip = self.cache.get(ip)
            if matches_ip is not None:
                return matches_ip

        # Query MongoDB to see if the IP is blacklisted
        matches_ip = IPNetwork.matches_ip(
            ip, read_preference=self.read_preference)

        # Cache the result if caching is enabled
        if self.cache is not None:
            self.cache[ip] = matches_ip

        return matches_ip
class LDAPAuthenticationBackend(object):
    """Authentication backend that verifies users against an LDAP server.

    A user is authenticated when the service account can find exactly one
    matching record, the user satisfies the configured group membership
    check, and a bind with the user's DN and password succeeds.
    """

    CAPABILITIES = (AuthBackendCapability.CAN_AUTHENTICATE_USER,
                    AuthBackendCapability.HAS_USER_INFORMATION,
                    AuthBackendCapability.HAS_GROUP_INFORMATION)

    def __init__(self, bind_dn, bind_password, base_ou, group_dns, host,
                 port=389, scope='subtree', id_attr='uid', use_ssl=False,
                 use_tls=False, cacert=None, network_timeout=10.0,
                 chase_referrals=False, debug=False, client_options=None,
                 group_dns_check='and', cache_user_groups_response=True,
                 cache_user_groups_cache_ttl=120,
                 cache_user_groups_cache_max_size=100):
        """Validate and store the connection and query settings.

        :raises ValueError: for a missing bind DN, bind password, host,
            base OU or group list, for ``use_ssl`` combined with
            ``use_tls``, for a ``cacert`` path that is not a file, and for
            an unknown ``scope`` or ``group_dns_check`` value.
        """
        if not bind_dn:
            raise ValueError(
                'Bind DN to query the LDAP server is not provided.')

        if not bind_password:
            raise ValueError(
                'Password for the bind DN to query the LDAP server is not provided.'
            )

        if not host:
            raise ValueError('Hostname for the LDAP server is not provided.')

        self._bind_dn = bind_dn
        self._bind_password = bind_password
        self._host = host

        if port:
            self._port = port
        elif not port and not use_ssl:
            LOG.warn('Default port 389 is used for the LDAP query.')
            self._port = 389
        elif not port and use_ssl:
            LOG.warn('Default port 636 is used for the LDAP query over SSL.')
            self._port = 636

        if use_ssl and use_tls:
            raise ValueError('SSL and TLS cannot be both true.')

        if cacert and not os.path.isfile(cacert):
            raise ValueError(
                'Unable to find the cacert file "%s" for the LDAP connection.'
                % (cacert))

        self._use_ssl = use_ssl
        self._use_tls = use_tls
        self._cacert = cacert
        self._network_timeout = network_timeout
        self._chase_referrals = chase_referrals
        self._debug = debug
        self._client_options = client_options

        if not id_attr:
            LOG.warn(
                'Default to "uid" for the user attribute in the LDAP query.')

        if not base_ou:
            raise ValueError('Base OU for the LDAP query is not provided.')

        if scope not in SEARCH_SCOPES.keys():
            raise ValueError('Scope value for the LDAP query must be one of '
                             '%s.' % str(SEARCH_SCOPES.keys()))

        self._id_attr = id_attr or 'uid'
        self._base_ou = base_ou
        self._scope = SEARCH_SCOPES[scope]

        if not group_dns:
            raise ValueError('One or more user groups must be specified.')

        if group_dns_check not in VALID_GROUP_DNS_CHECK_VALUES:
            valid_values = ', '.join(VALID_GROUP_DNS_CHECK_VALUES)
            raise ValueError(
                'Invalid value "%s" for group_dns_check option. Valid values are: '
                '%s.' % (group_dns_check, valid_values))

        self._group_dns_check = group_dns_check
        self._group_dns = group_dns

        self._cache_user_groups_response = cache_user_groups_response
        self._cache_user_groups_cache_ttl = int(cache_user_groups_cache_ttl)
        self._cache_user_groups_cache_max_size = int(
            cache_user_groups_cache_max_size)

        # Cache which stores LDAP groups response for a particular user
        if self._cache_user_groups_response:
            self._user_groups_cache = TTLCache(
                maxsize=self._cache_user_groups_cache_max_size,
                ttl=self._cache_user_groups_cache_ttl)
        else:
            self._user_groups_cache = None

    def authenticate(self, username, password):
        """Return True when ``username`` / ``password`` are valid.

        Every failure (bind, lookup, group check, user bind) is logged and
        reported as ``False``.

        :raises ValueError: if ``password`` is empty.
        """
        connection = None

        if not password:
            raise ValueError('password cannot be empty')

        try:
            # Instantiate connection object and bind with service account.
            try:
                connection = self._init_connection()
                connection.simple_bind_s(self._bind_dn, self._bind_password)
            except Exception:
                LOG.exception('Failed to bind with "%s".' % self._bind_dn)
                return False

            # Search for user and fetch the DN of the record.
            try:
                user_dn = self._get_user_dn(connection=connection,
                                            username=username)
            except ValueError as e:
                LOG.exception(str(e))
                return False
            except Exception:
                LOG.exception('Unexpected error when querying for user "%s".'
                              % username)
                return False

            # Search if user is member of pre-defined groups.
            try:
                user_groups = self._get_groups_for_user(connection=connection,
                                                        user_dn=user_dn,
                                                        username=username)

                # Assume group entries are not case sensitive.
                user_groups = set([entry.lower() for entry in user_groups])
                required_groups = set(
                    [entry.lower() for entry in self._group_dns])

                result = self._verify_user_group_membership(
                    username=username,
                    required_groups=required_groups,
                    user_groups=user_groups,
                    check_behavior=self._group_dns_check)

                if not result:
                    return False
            except Exception:
                LOG.exception(
                    'Unexpected error when querying membership for user "%s".'
                    % username)
                return False

            # Forget the service connection BEFORE unbinding it, so the
            # ``finally`` below never unbinds the same connection a second
            # time (e.g. when opening the user connection fails).
            service_connection, connection = connection, None
            self._clear_connection(service_connection)

            # Authenticate with the user DN and password.
            try:
                connection = self._init_connection()
                connection.simple_bind_s(user_dn, password)
                LOG.info('Successfully authenticated user "%s".' % username)
                return True
            except Exception:
                LOG.exception('Failed authenticating user "%s".' % username)
                return False
        except ldap.LDAPError:
            LOG.exception('Unexpected LDAP error.')
            return False
        finally:
            self._clear_connection(connection)

        return False

    def get_user(self, username):
        """
        Retrieve user information.

        Returns ``None`` (after logging) when the lookup fails.

        :rtype: ``dict``
        """
        connection = None

        try:
            connection = self._init_connection()
            connection.simple_bind_s(self._bind_dn, self._bind_password)

            _, user_info = self._get_user(connection=connection,
                                          username=username)
        except Exception:
            LOG.exception('Failed to retrieve details for user "%s"'
                          % (username))
            return None
        finally:
            self._clear_connection(connection)

        user_info = dict(user_info)
        return user_info

    def get_user_groups(self, username):
        """
        Return a list of all the groups user is a member of.

        Returns ``None`` (after logging) when the lookup fails.

        :rtype: ``list`` of ``str``
        """
        # First try to get result from a local in-memory cache
        groups = self._get_user_groups_from_cache(username=username)
        if groups is not None:
            return groups

        connection = None

        try:
            connection = self._init_connection()
            connection.simple_bind_s(self._bind_dn, self._bind_password)

            user_dn = self._get_user_dn(connection=connection,
                                        username=username)
            groups = self._get_groups_for_user(connection=connection,
                                               user_dn=user_dn,
                                               username=username)
        except Exception:
            LOG.exception('Failed to retrieve groups for user "%s"'
                          % (username))
            return None
        finally:
            self._clear_connection(connection)

        # Store result in cache (if caching is enabled)
        self._set_user_groups_in_cache(username=username, groups=groups)

        return groups

    def _init_connection(self):
        """
        Initialize connection to the LDAP server.
        """
        # Use CA cert bundle to validate certificate if present.
        # NOTE: without a cacert, certificate verification is switched off
        # (OPT_X_TLS_NEVER).
        if self._use_ssl or self._use_tls:
            if self._cacert:
                ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self._cacert)
            else:
                ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT,
                                ldap.OPT_X_TLS_NEVER)

        if self._debug:
            trace_level = 2
        else:
            trace_level = 0

        # Setup connection and options.
        protocol = 'ldaps' if self._use_ssl else 'ldap'
        endpoint = '%s://%s:%d' % (protocol, self._host, int(self._port))
        connection = ldap.initialize(endpoint, trace_level=trace_level)
        connection.set_option(ldap.OPT_DEBUG_LEVEL, 255)
        connection.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
        connection.set_option(ldap.OPT_NETWORK_TIMEOUT, self._network_timeout)

        if self._chase_referrals:
            connection.set_option(ldap.OPT_REFERRALS, 1)
        else:
            connection.set_option(ldap.OPT_REFERRALS, 0)

        client_options = self._client_options or {}
        for option_name, option_value in client_options.items():
            connection.set_option(int(option_name), option_value)

        if self._use_tls:
            connection.start_tls_s()

        return connection

    def _clear_connection(self, connection):
        """
        Unbind and close connection to the LDAP server.

        A falsy ``connection`` is ignored.
        """
        if connection:
            connection.unbind_s()

    def _get_user_dn(self, connection, username):
        """Return only the DN part of the user's LDAP record."""
        user_dn, _ = self._get_user(connection=connection, username=username)
        return user_dn

    def _get_user(self, connection, username):
        """
        Retrieve LDAP user record for the provided username.

        Note: This method escapes ``username`` so it can safely be used as a
        filter in the query.

        :raises ValueError: if no record or more than one record matches.

        :rtype: ``tuple`` (``user_dn``, ``user_info_dict``)
        """
        username = ldap.filter.escape_filter_chars(username)
        query = '%s=%s' % (self._id_attr, username)
        result = connection.search_s(self._base_ou, self._scope, query, [])

        # Entries without a DN are dropped before counting matches.
        if result:
            entries = [entry for entry in result if entry[0] is not None]
        else:
            entries = []

        if len(entries) <= 0:
            msg = ('Unable to identify user for "%s".' % (query))
            raise ValueError(msg)

        if len(entries) > 1:
            msg = ('More than one users identified for "%s".' % (query))
            raise ValueError(msg)

        user_tuple = entries[0]
        return user_tuple

    def _get_groups_for_user(self, connection, user_dn, username):
        """
        Return a list of all the groups user is a member of.

        :rtype: ``list`` of ``str``
        """
        # First try to get result from a local in memory cache
        groups = self._get_user_groups_from_cache(username=username)
        if groups is not None:
            return groups

        filter_values = [user_dn, user_dn, username]
        query = ldap.filter.filter_format(USER_GROUP_MEMBERSHIP_QUERY,
                                          filter_values)
        result = connection.search_s(self._base_ou, self._scope, query, [])

        if result:
            groups = [entry[0] for entry in result if entry[0] is not None]
        else:
            groups = []

        # Store result in cache (if caching is enabled)
        self._set_user_groups_in_cache(username=username, groups=groups)

        return groups

    def _verify_user_group_membership(self, username, required_groups,
                                      user_groups, check_behavior='and'):
        """
        Validate that the user is a member of required groups based on the
        check behavior defined in the config (and / or).

        ``required_groups`` and ``user_groups`` are sets.  Returns True on
        success and False otherwise.
        """
        if check_behavior == 'and':
            additional_msg = (
                'user needs to be member of all the following groups "%s" for '
                'authentication to succeeed')
        elif check_behavior == 'or':
            additional_msg = (
                'user needs to be member of one or more of the following groups "%s" '
                'for authentication to succeeed')

        additional_msg = additional_msg % (str(list(required_groups)))

        LOG.debug('Verifying user group membership using "%s" behavior (%s)'
                  % (check_behavior, additional_msg))

        if check_behavior == 'and':
            if required_groups.issubset(user_groups):
                return True
        elif check_behavior == 'or':
            if required_groups.intersection(user_groups):
                return True

        msg = (
            'Unable to verify membership for user "%s (required_groups=%s,'
            'actual_groups=%s,check_behavior=%s)".'
            % (username, str(required_groups), str(user_groups),
               check_behavior))
        # NOTE(review): LOG.exception is called here without an active
        # exception -- consider LOG.error.
        LOG.exception(msg)

        # Final safe guard
        return False

    def _get_user_groups_from_cache(self, username):
        """
        Get value from per-user group cache (if caching is enabled).

        Returns ``None`` when caching is disabled or nothing is cached.
        """
        if not self._cache_user_groups_response:
            return None

        LOG.debug('Getting LDAP groups for user "%s" from cache' % (username))

        result = self._user_groups_cache.get(username, None)

        if result is None:
            LOG.debug('LDAP groups cache for user "%s" is empty' % (username))
        else:
            LOG.debug('Found LDAP groups cache for user "%s"' % (username))

        return result

    def _set_user_groups_in_cache(self, username, groups):
        """
        Store value in per-user group cache (if caching is enabled).
        """
        if not self._cache_user_groups_response:
            return None

        LOG.debug('Storing groups for user "%s" in cache' % (username))

        self._user_groups_cache[username] = groups
class Quotes(commands.Cog):
    """Cog providing commands to post, add and remove quoted messages."""

    def __init__(self, bot):
        self.bot = bot
        self.db = self.bot.db
        # Used to prevent duplicate messages in quick succession.
        self.cache = TTLCache(ttl=300, maxsize=120)

    @staticmethod
    def int_or_string(arg):
        """
        Tell whether ``arg`` can be converted with ``int()``.

        Returns:
            True: arg is an integer (a quote ID)
            False: arg is a string (a username)
        """
        try:
            int(arg)
        except (TypeError, ValueError):
            # Only conversion failures mean "not an integer"; anything else
            # should surface instead of being swallowed.
            return False
        return True

    @staticmethod
    def _create_embed(quote):
        """Parse quote and create discord.py embed."""
        embed = discord.Embed.from_dict({
            "title": "\n",
            "description": quote.message,
            "color": quote.member.top_role.color,
            "author": {
                "name": quote.member.display_name,
                "icon_url": str(quote.member.user.avatar_url)
            },
            "footer": {
                'text': quote.timestamp.strftime("%b %d %Y")
            }
        })
        # Only the first attachment is shown as the embed image.
        if quote.attachments:
            embed.set_image(
                url=urllib.parse.urljoin(config.HOST + "files/",
                                         quote.attachments[0]))
        return embed

    async def _add_quote(self, ctx, quote: QuoteData):
        """Add a quote to the database and return the refreshed row."""
        quote = Quote(**quote.to_dict())
        self.db.add(quote)
        self.db.commit()
        self.db.refresh(quote)
        return quote

    async def _remove_quote(self, ctx, message):
        """Remove a quote from the database and confirm in the channel."""
        self.db.query(Quote).filter_by(guild_id=ctx.guild.id) \
            .filter_by(message_id=message.id) \
            .delete()
        self.db.commit()
        await ctx.channel.send(REMOVED_MSG)

    async def _get_quote(self, ctx, user_id=None, quote_id=None, limit=None):
        """
        Get quote from the current server.
        All keywords are optional.

        Keyword Arguments:
            user_id: retrieves from a specific user
            quote_id: retrieves a specific quote (ignored when user_id is set)
            limit: limit query to limit-n of quotes.

        Returns:
            A list of quotes when ``limit`` is truthy, otherwise a single
            quote or None. Results are ordered randomly.
        """
        filters = [Quote.guild_id == ctx.guild.id]
        if user_id:
            filters.append(Quote.user_id == user_id)
        elif quote_id:
            filters.append(Quote.message_id == quote_id)

        query = self.db.query(Quote).filter(and_(*filters)) \
            .order_by(func.random())

        if limit:
            return query.limit(limit).all()
        return query.first()

    async def _get_quote_user(self, ctx, username):
        """
        Get user who send quote from the database.

        Attempts to use findUser. Otherwise falls back to quering
        the guild_membership from the database.

        Returns user_id of user found or None.
        """
        # find user using findUser
        user = await findUser(ctx, username)
        if user:
            return user.id

        # fallback and try to find user in database.
        membership = self.db.query(GuildMembership).filter(
            and_(GuildMembership.display_name == username,
                 GuildMembership.guild_id == ctx.guild.id)).first()
        return membership.user_id if membership else None

    @commands.guild_only()
    @commands.command(name='quote')
    async def quote(self, ctx, *args):
        """
        Send a message from the quote database.
        First argument can either be a user's id or a message id.

        Possible Arguments:
            user_id: Retrieves a quote from a specific user
            quote_id: Retrieves a specific quote
        """
        user = None

        if args:
            if self.int_or_string(args[0]):
                # arg is a quote_id
                quote = await self._get_quote(ctx, quote_id=args[0])
                if not quote:
                    await ctx.channel.send(NOQUOTE_MSG)
                    return
                await ctx.channel.send(embed=self._create_embed(quote))
                return

            # arg is a username
            user = await self._get_quote_user(ctx, args[0])
            if not user:
                await ctx.channel.send(NOUSER_MSG)
                return

        # limit retries from cached duplicates.
        quotes = await self._get_quote(ctx, user_id=user, limit=20)
        if not quotes:
            await ctx.channel.send(NORESULTS_MSG)
            return

        for quote in quotes:
            # check to see if quote was already sent recently
            if not self.cache.get(quote.message_id, None):
                self.cache[quote.message_id] = quote
                break
        else:
            # every candidate was posted recently
            quote = None

        if quote:
            await ctx.channel.send(embed=self._create_embed(quote))
        else:
            await ctx.channel.send(ALLSEEN_MSG)

    @commands.guild_only()
    @admin_only()
    @commands.command(name="add_quote")
    async def quote_add(self, ctx, *args):
        """
        Add quote Command
        Takes an arbitrary number of message IDs and adds them as quotes in the database
        If multiple ids are provided, messages are merged.
        """
        # no message ids given
        if not args:
            return

        data = QuoteData()
        try:
            # first message is the main quote message.
            data.root_message = await ctx.channel.fetch_message(int(args[0]))
            # any additional messages are child messages.
            for arg in args[1:]:
                data.child_messages.append(
                    await ctx.channel.fetch_message(int(arg)))

            if not data.child_messages:
                # Single quote message: save every attachment and record
                # the generated file name.
                for a in data.root_message.attachments:
                    # time prefix keeps saved file names distinct
                    filename = f"{str(time.time())}_{a.filename}"
                    fp = Path(f"./public/files/{filename}")
                    await a.save(fp)
                    data.attachments.append(filename)
            else:
                # no attachments in multi-quote messages. (for now)
                if data.root_message.attachments:
                    raise QuoteAttachmentError
                for m in data.child_messages:
                    # no attachments in multi-quote messages. (for now)
                    if m.attachments:
                        raise QuoteAttachmentError
                    # Check that all messages retrieved are from the same user.
                    if m.author.id != data.root_message.author.id:
                        raise QuoteAuthorError

            quote = await self._add_quote(ctx, data)
            await ctx.channel.send(embed=self._create_embed(quote))
        except (QuoteResultError, discord.NotFound):
            # fetch_message raises NotFound for an unknown message id.
            await ctx.channel.send(NORESULTS_MSG)
        except QuoteAuthorError:
            await ctx.channel.send(SAMEUSER_MSG)
        except QuoteAttachmentError:
            await ctx.channel.send(NOATTACHMENTS_MSG)

    @commands.guild_only()
    @admin_only()
    @commands.command(name="remove_quote")
    async def remove(self, ctx, *, message_id: int):
        """Remove quote command"""
        try:
            message = await ctx.channel.fetch_message(message_id)
        except discord.NotFound:
            # fetch_message signals a missing message by raising, it never
            # returns a falsy value.
            await ctx.channel.send(NORESULTS_MSG)
            return
        await self._remove_quote(ctx, message)

    @commands.command(name="cleo")
    async def cleo(self, ctx, *args):
        """Command to post a link to a server's quote page"""
        await ctx.channel.send(config.HOST_URL + "quotes/" + str(ctx.guild.id))
class SlackTeam:
    """Bridge state and event handlers for one Slack team (workspace)."""

    def __init__(self, data: dict, slack: 'Slack'):
        """Set up the Slack app, register Slack/Discord listeners and schedule
        the background tasks that fill the caches.

        ``data`` must contain the keys ``team_id``, ``token`` and ``bot_id``.
        """
        self.slack = slack
        self.bot = self.slack.bot
        self.team_id = data['team_id']
        self.token = data['token']
        self.bot_id = data['bot_id']
        # placeholder until get_team_info() has fetched the real name
        self.name = self.team_id
        self.app = AsyncApp(token=self.token)

        self.app.view("socket_modal_submission")(self.submission)
        self.app.event("message")(self.slack_message)
        self.app.event("member_joined_channel")(self.slack_member_joined)
        self.app.event("channel_left")(self.slack_channel_left)
        self.handler = AsyncSocketModeHandler(self.app, config.SLACK_APP_TOKEN)
        self.bot.loop.create_task(self.handler.start_async())

        self.bot.add_listener(self.on_message, 'on_message')
        self.bot.add_listener(self.on_raw_message_edit, 'on_raw_message_edit')
        self.bot.add_listener(self.on_raw_message_delete,
                              'on_raw_message_delete')

        self.channels: list[SlackChannel] = []
        self.members: list[SlackMember] = []
        self.slack.bot.loop.create_task(self.get_team_info())

        self.discord_messages: TTLCache[int, DiscordMessage] = TTLCache(
            ttl=600.0, maxsize=500)
        self.slack_messages: TTLCache[str, SlackMessage] = TTLCache(
            ttl=600.0, maxsize=500)
        # maps a message id on one platform to its counterpart on the other
        self.message_links = TTLCache(ttl=86400.0, maxsize=1000)
        self.initialize_data()

        self.messages_cached = asyncio.Event()
        self.members_cached = asyncio.Event()
        self.channels_cached = asyncio.Event()

        self.slack.bot.loop.create_task(self.cache_members())
        self.slack.bot.loop.create_task(self.cache_channels())
        self.slack.bot.loop.create_task(self.cache_messages())

    def initialize_data(self):
        """Initilises the data in the database if needed."""
        data = db.slack_bridge.find_one({'team_id': self.team_id})
        if not data:
            data = {
                'team_id': self.team_id,
                'aliases': [],
                'bridges': [],
                'tokens': {}
            }
            db.slack_bridge.insert_one(data)

    async def get_team_info(self):
        """Fetch the team info from Slack and store the team name."""
        team_data = await self.app.client.team_info()
        self.name = team_data['team']['name']

    async def get_channels(self) -> list[dict]:
        """Function for getting channels, makes call to slack api and filters out channels bot isnt member of."""
        channels_data = await self.app.client.conversations_list(
            team_id=self.team_id)
        return [
            channel for channel in channels_data['channels']
            if channel['is_member']
        ]

    async def get_members(self) -> list[dict]:
        """Function for getting members, makes call to slack api and filters out bot accounts and slack bot
        account."""
        members_data = await self.app.client.users_list(team_id=self.team_id)
        return [
            member for member in members_data['members']
            if not member['is_bot'] and member['id'] != 'USLACKBOT'
        ]

    async def add_user(self, user_id: str):
        """Fetch a Slack user, resolve its Discord member and cache it.

        Returns the newly created SlackMember.
        """
        user_data = await self.app.client.users_info(user=user_id)
        slack_member = SlackMember(user_data['user'], self.slack)
        await slack_member.get_discord_member()
        # self.members is a flat list; it cannot be indexed by team id.
        self.members.append(slack_member)
        return slack_member

    def get_user(self,
                 slack_id: str = None,
                 *,
                 discord_id: int = None) -> Optional[SlackMember]:
        """Get SlackMember via slack id or discord id."""
        for member in self.members:
            if (slack_id is not None and member.id == slack_id) or \
                    (discord_id is not None and member.discord_member and member.discord_member.id == discord_id):
                return member
        return None

    def get_channel(self,
                    slack_id: str = None,
                    *,
                    discord_id: int = None) -> Optional[SlackChannel]:
        """Get SlackChannel via slack id or discord id."""
        for channel in self.channels:
            if (slack_id is not None and channel.id == slack_id) or \
                    (discord_id is not None and channel.discord_channel and channel.discord_channel.id == discord_id):
                return channel
        return None

    async def cache_messages(self):
        """Cache messages in the database.

        Waits for members and channels to be cached, deletes stored messages
        older than 24 hours and loads the remaining id links into
        ``message_links``; sets ``messages_cached`` when done.
        """
        await self.members_cached.wait()
        await self.channels_cached.wait()
        twenty_four_hours = 24 * 60 * 60
        messages = db.slack_messages.find({
            'team_id': self.team_id
        }).sort('timestamp', 1)
        for message in messages:
            if time.time() - message['timestamp'] > twenty_four_hours:
                db.slack_messages.delete_one(message)
                continue

            # key is the id on the originating platform, value the mirrored id
            if message['origin'] == 'discord':
                self.message_links[message['discord_message_id']] = message[
                    'slack_message_id']
            elif message['origin'] == 'slack':
                self.message_links[message['slack_message_id']] = message[
                    'discord_message_id']

        self.messages_cached.set()
        self.slack.bot.logger.debug(
            f'{len(self.message_links)} Slack and Discord messages cached for team [{self.team_id}]'
        )

    async def cache_channels(self):
        """Caches channels."""
        channels = await self.get_channels()
        for channel_data in channels:
            if self.get_channel(channel_data['id']):
                continue
            channel = SlackChannel(
                team=self,
                channel_id=channel_data['id'],
                slack=self.slack,
            )
            self.channels.append(channel)

        self.channels_cached.set()
        self.slack.bot.logger.debug(
            f'{len(channels)} Slack channels cached for team [{self.team_id}]')

    async def cache_members(self):
        """Caches members."""
        members = await self.get_members()
        for member_data in members:
            if self.get_user(member_data['id']):
                continue
            member = SlackMember(
                data=member_data,
                slack=self.slack,
            )
            # resolved in the background; not awaited here
            self.slack.bot.loop.create_task(member.get_discord_member())
            self.members.append(member)

        self.members_cached.set()
        self.slack.bot.logger.debug(
            f'{len(members)} Slack member cached for team [{self.team_id}]')

    async def delete_discord_message(self,
                                     channel_id: int,
                                     message_id: int,
                                     *,
                                     ts: str = None):
        """Delete a Discord message and forget the Slack message ``ts`` it mirrors."""
        await self.slack.bot.http.delete_message(channel_id, message_id)
        if ts in self.slack_messages:
            del self.slack_messages[ts]
        db.slack_messages.delete_one({'slack_message_id': ts})

    async def delete_slack_message(self,
                                   message_id: str,
                                   discord_channel_id: int,
                                   *,
                                   discord_message_id: int = None):
        """Delete a Slack message and forget the Discord message it mirrors."""
        slack_channel = self.get_channel(discord_id=discord_channel_id)
        await self.app.client.chat_delete(channel=slack_channel.id,
                                          ts=message_id)
        if discord_message_id in self.discord_messages:
            del self.discord_messages[discord_message_id]
        db.slack_messages.delete_one(
            {'discord_message_id': discord_message_id})

    async def get_slack_message(
            self,
            channel_id: str,
            message_id: str,
            discord_message_id: int = None) -> Optional[SlackMessage]:
        """Fetch one message from Slack's history and wrap it in a SlackMessage.

        Returns None when ``message_id`` is None or nothing is found.
        """
        if message_id is None:
            return None
        result = await self.app.client.conversations_history(
            channel=channel_id, inclusive=True, oldest=message_id, limit=1)
        if not result or not result['messages']:
            return None
        message = result['messages'][0]
        # rebuild an event-shaped payload as expected by SlackMessage
        data = {
            'event': {
                'team_id': message['team_id'],
                'user': message['user']
                if 'user' in message else message['username'],
                'discord_message_id': discord_message_id,
                'channel': channel_id,
                'text': message['text'],
                'ts': message_id,
                'files': [],
                'subtype': ''
            }
        }
        return SlackMessage(data, self.slack)

    async def get_discord_message(
            self,
            channel_id: int,
            message_id: int,
            slack_message_id: str = None) -> Optional[DiscordMessage]:
        """Fetch a Discord message and wrap it in a DiscordMessage.

        If the message cannot be fetched its database record is removed and
        None is returned.
        """
        channel = self.slack.bot.get_channel(channel_id)
        try:
            message = await channel.fetch_message(message_id)
        except Exception:
            # most-likely errors when message has been deleted.
            # ``Exception`` rather than a bare except so task cancellation
            # is not swallowed.
            db.slack_messages.delete_one({'discord_message_id': message_id})
            return None
        if not message:
            db.slack_messages.delete_one({'discord_message_id': message_id})
            return None
        discord_message = DiscordMessage(message, self.slack)
        discord_message.slack_message_id = slack_message_id
        return discord_message

    async def submission(self, ack):
        """Function that acknowledges events."""
        await ack()

    async def slack_channel_left(self, body: dict):
        """Function called when bot leaves a channel"""
        event = body['event']
        channel_id = event['channel']
        channel = self.get_channel(channel_id)
        if channel:
            db.slack_bridge.update_one(
                {'team_id': self.team_id},
                {'$pull': {
                    'bridges': {
                        'slack_channel_id': channel_id
                    }
                }})
            self.channels.remove(channel)

    async def slack_member_joined(self, body: dict):
        """Function called when a member (or the bot itself) joins a channel."""
        event = body['event']
        channel_id = event['channel']
        user_id = event['user']

        if user_id == self.bot_id:
            # the bot joined: start tracking the channel (once)
            if not self.get_channel(channel_id):
                self.channels.append(
                    SlackChannel(self, channel_id, self.slack))
        elif not self.get_user(user_id):
            # the event is handled for every channel a user joins, so only
            # add members that are not cached yet
            user_data = await self.app.client.users_info(user=user_id)
            self.members.append(SlackMember(user_data['user'], self.slack))

    async def handle_delete_message(self, body: dict):
        """Mirror a Slack message deletion to Discord."""
        event = body['event']
        ts = event['deleted_ts']
        channel_id = event['channel']

        slack_message = self.slack_messages.get(ts, None)
        if slack_message:
            return await slack_message.delete()

        # not in the short-lived cache: fall back to the stored id link
        slack_message_link = self.message_links.get(ts, None)
        slack_channel = self.get_channel(slack_id=channel_id)
        if (not slack_channel or not slack_channel.discord_channel
                or not slack_message_link):
            return
        return await self.delete_discord_message(
            slack_channel.discord_channel.id, slack_message_link, ts=ts)

    async def handle_edit_message(self, body: dict):
        """Mirror a Slack message edit to Discord."""
        event = body['event']
        ts = event['message']['ts']
        channel_id = event['channel']

        # check if message was edited by the bot
        edit_message = next(
            filter(lambda dm: dm.slack_message_id == ts,
                   self.discord_messages.values()), None)
        if edit_message:
            return

        slack_message = self.slack_messages.get(ts, None)
        if not slack_message:
            slack_message_link = self.message_links.get(ts, None)
            slack_message = await self.get_slack_message(
                channel_id, ts, slack_message_link)

        if slack_message:
            slack_message.text = event['message']['text']
            asyncio.create_task(slack_message.send_to_discord(edit=True))

    async def slack_message(self, body):
        """Function called on message even from slack."""
        await self.messages_cached.wait()
        subtype = body['event'].get('subtype')

        # The handlers read body['event'] themselves, so they need the full
        # body, not the inner event.
        if subtype == 'message_changed':
            return await self.handle_edit_message(body)
        if subtype == 'message_deleted':
            return await self.handle_delete_message(body)
        if subtype == 'bot_message':
            return

        message = SlackMessage(body, self.slack)
        asyncio.create_task(message.send_to_discord())

    # Discord events

    async def on_raw_message_delete(self,
                                    payload: discord.RawMessageDeleteEvent):
        """Mirror a Discord message deletion to Slack."""
        channel_id = payload.channel_id
        slack_channel = self.get_channel(discord_id=channel_id)
        if not slack_channel:
            return
        team = slack_channel.team
        await team.messages_cached.wait()

        message_id = payload.message_id
        discord_message = team.discord_messages.get(message_id, None)
        if discord_message:
            await discord_message.delete()
            return

        # Messages that were never bridged (or whose link expired) have no
        # entry; there is nothing to delete on the Slack side then.
        discord_message_link = team.message_links.get(message_id, None)
        if discord_message_link is None:
            return
        return await team.delete_slack_message(
            discord_message_link, channel_id, discord_message_id=message_id)

    async def on_raw_message_edit(self,
                                  payload: discord.RawMessageUpdateEvent):
        """
        Function called on on_edit_message event, used for dealing with message edit events on the discord side
        and doing the same on the slack end.
        """
        if 'content' not in payload.data:
            return
        channel_id = payload.channel_id
        slack_channel = self.get_channel(discord_id=channel_id)
        if not slack_channel:
            return

        team = slack_channel.team
        await team.messages_cached.wait()

        message_id = payload.message_id
        content = payload.data['content']
        cached_message = team.discord_messages.get(message_id, None)
        if not cached_message:
            cached_message_link = team.message_links.get(message_id, None)
            cached_message = await team.get_discord_message(
                channel_id, message_id, cached_message_link)

        if cached_message is None:
            return

        cached_message.text = content
        await cached_message.send_to_slack(edit=True)

    async def on_message(self, message: discord.Message):
        """Function call on on_message event, used for identifying discord bridge channel and forwarding the
        messages to slack."""

        # ignore webhook messages and pms
        if not message.guild or message.webhook_id:
            return

        slack_channel = self.get_channel(discord_id=message.channel.id)
        if not slack_channel:
            return

        await self.messages_cached.wait()

        discord_message = DiscordMessage(message, self.slack)
        await discord_message.send_to_slack()
class KodiClient(object):
    """Client for Kodi's JSON-RPC audio library API with a TTL item cache."""

    def __init__(self, config):
        """
        :param config: mapping with ``host`` and ``port`` and, optionally,
                       ``user`` and ``password`` (both needed for auth).
        """
        # Per-item cache, keyed '<kind>.<id>' (plus 'trackurl.<url>').
        self._cache = TTLCache(maxsize=2048, ttl=3600)
        if 'user' in config and 'password' in config:
            self.auth = (config['user'], config['password'])
        else:
            self.auth = None
        self.host = config['host']
        self.port = config['port']
        # Number of items requested per API call by _make_generator.
        self.chunk_size = 750
        self._api = Server(
            url='http://{host}:{port}/jsonrpc'.format(**config),
            auth=self.auth)

    def _make_generator(self, method, data_key, **params):
        """
        Yield all items of a paged API call, ``chunk_size`` items per request.

        :param method: API method to call; receives ``params`` plus ``limits``.
        :param data_key: key of the item list in the response.
        """
        logger.debug("Fetching first chunk of {}".format(data_key))
        params.update({'limits': {'start': 0, 'end': self.chunk_size}})
        resp = method(**params)
        # The item list may be missing when there are no results
        # (NOTE(review): Kodi appears to omit it for empty result sets).
        for d in resp.get(data_key) or []:
            yield d
        num_total = resp['limits']['total']
        cur_start = self.chunk_size
        while cur_start < num_total:
            params['limits']['start'] = cur_start
            params['limits']['end'] = cur_start + self.chunk_size
            logger.debug("Fetching next chunk from #{}".format(cur_start))
            resp = method(**params)
            for d in resp.get(data_key) or []:
                yield d
            cur_start += self.chunk_size

    def _get_details(self, kind, item_id, method):
        """
        Return the details of one item, from the item cache when possible.

        :param kind: 'artist', 'album' or 'song'; used for the cache key, the
                     id parameter name, the PROPERTIES entry and the result key.
        :param method: API method that fetches the details.
        :return: the details, or ``None`` if the lookup failed.
        """
        key = '{}.{}'.format(kind, item_id)
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        try:
            call_params = {
                '{}id'.format(kind): item_id,
                'properties': PROPERTIES[kind],
            }
            details = method(**call_params)['{}details'.format(kind)]
            self._cache[key] = details
            return details
        except Exception:
            # Failures are not cached: a stored None reads back as a miss.
            logger.debug("Could not fetch details for {}".format(key),
                         exc_info=True)
            return None

    @cached()
    def get_artists(self):
        """Return all artists and store each one in the item cache."""
        artists = list(self._make_generator(
            self._api.AudioLibrary.GetArtists, 'artists',
            properties=PROPERTIES['artist']))
        self._cache.update({'artist.{}'.format(a['artistid']): a
                            for a in artists})
        return artists

    def get_artist(self, artist_id):
        """Return the artist with the given id, or ``None`` on failure."""
        return self._get_details(
            'artist', int(artist_id),
            self._api.AudioLibrary.GetArtistDetails)

    @cached()
    def get_albums(self, artist_id=None, recently_added=False):
        """
        Return albums, optionally filtered by artist.

        With ``recently_added`` the recently added albums are returned and
        ``artist_id`` is ignored; that result is not put into the item cache.
        """
        if recently_added:
            return self._api.AudioLibrary.GetRecentlyAddedAlbums(
                properties=PROPERTIES['album'])['albums']
        if artist_id is not None:
            artist_id = int(artist_id)
        params = {'properties': PROPERTIES['album'],
                  'data_key': 'albums'}
        if artist_id:
            params['filter'] = {'artistid': artist_id}
        albums = list(self._make_generator(
            self._api.AudioLibrary.GetAlbums, **params))
        self._cache.update({'album.{}'.format(a['albumid']): a
                            for a in albums})
        return albums

    def get_album(self, album_id):
        """Return the album with the given id, or ``None`` on failure."""
        return self._get_details(
            'album', int(album_id),
            self._api.AudioLibrary.GetAlbumDetails)

    @cached()  # First-level cache for accessing all tracks
    def get_songs(self, album_id=None):
        """Return songs, optionally filtered by album."""
        if album_id is not None:
            album_id = int(album_id)
        params = {'properties': PROPERTIES['song'],
                  'data_key': 'songs'}
        if album_id:
            params['filter'] = {'albumid': album_id}
        songs = list(self._make_generator(
            self._api.AudioLibrary.GetSongs, **params))
        # Second level cache so that get_song doesn't have to make an API call
        self._cache.update({'song.{}'.format(s['songid']): s for s in songs})
        return songs

    def get_song(self, song_id):
        """Return the song with the given id, or ``None`` on failure."""
        return self._get_details(
            'song', int(song_id),
            self._api.AudioLibrary.GetSongDetails)

    @cached()
    def get_url(self, filepath):
        """
        Return a download URL for ``filepath``.

        The URL embeds the credentials when auth is configured. The mapping
        from URL back to ``filepath`` is kept under 'trackurl.<url>'.
        """
        path = self._api.Files.PrepareDownload(filepath)
        url = "http://{}{}:{}/{}".format(
            "{}:{}@".format(*self.auth) if self.auth else '',
            self.host, self.port, path['details']['path'])
        self._cache['trackurl.{}'.format(url)] = filepath
        return url
class FuturesTrader: def __init__(self): self.client: AsyncClient = None self.state: dict = None self.prices: dict = {} self.symbols: dict = {} self.price_streamer = None self.clocks = NamedLock() self.olock = asyncio.Lock() # lock to place only one order at a time self.slock = asyncio.Lock() # lock for stream subscriptions self.order_queue = asyncio.Queue() # cache to disallow orders with same symbol, entry and first TP for 10 mins self.sig_cache = TTLCache(maxsize=100, ttl=600) self.balance = 0 self.results_handler = None async def init(self, api_key, api_secret, state={}, test=False, loop=None): self.state = state self.client = await AsyncClient.create(api_key=api_key, api_secret=api_secret, testnet=test, loop=loop) self.manager = BinanceSocketManager(self.client, loop=loop) self.user_stream = UserStream(api_key, api_secret, test=test) if not self.state.get("streams"): self.state["streams"] = [] if not self.state.get("orders"): self.state["orders"] = {} await self._gather_orders() await self._watch_orders() await self._subscribe_futures_user() resp = await self.client.futures_exchange_info() for info in resp["symbols"]: self.symbols[info["symbol"]] = info resp = await self.client.futures_account_balance() for item in resp: if item["asset"] == "USDT": self.balance = float(item["balance"]) logging.info(f"Account balance: {self.balance} USDT", on="blue") async def queue_signal(self, signal: Signal): await self.order_queue.put(signal) async def close_trades(self, tag, coin=None): if coin is None: logging.info( f"Attempting to close all trades associated with channel {tag}", color="yellow") else: logging.info( f"Attempting to close {coin} trades associated with channel {tag}", color="yellow") async with self.olock: removed = [] for order_id, order in self.state["orders"].items(): if order.get("tag") != tag: continue if coin is not None and order["sym"] != f"{coin}USDT": continue children = [] + order["t_ord"] if order.get("s_ord"): children.append(order["s_ord"]) 
removed += children for oid in children: await self._cancel_order(oid, order["sym"]) quantity = 0 for tid, q in zip(order["t_ord"], order["t_q"]): if not self.state["orders"].get(tid, {}).get("filled"): quantity += q try: if quantity > 0: resp = await self.client.futures_create_order( symbol=order["sym"], positionSide="LONG" if order["side"] == "BUY" else "SHORT", side="SELL" if order["side"] == "BUY" else "BUY", type=OrderType.MARKET, quantity=self._round_qty(order["sym"], quantity), ) else: resp = await self.client.futures_cancel_order( symbol=order["sym"], origClientOrderId=order_id, ) logging.info( f"Closed position for order {order}, resp: {resp}", color="yellow") except Exception as err: logging.error( f"Failed to close position for order {order}, err: {err}" ) for oid in removed: self.state["orders"].pop(oid, None) if not removed: logging.info( f"Didn't find any matching positions for {tag} to close", color="yellow") async def _gather_orders(self): async def _gatherer(): logging.info("Waiting for orders to be queued...") while True: signal = await self.order_queue.get() if self.symbols.get(f"{signal.coin}USDT") is None: logging.info(f"Unknown symbol {signal.coin} in signal", color="yellow") continue async def _process(signal): # Process one order at a time for each symbol async with self.clocks.lock(signal.coin): registered = await self._register_order_for_signal( signal) if not registered: logging.info( f"Ignoring signal from {signal.tag} because order exists " f"for {signal.coin}", color="yellow") return for i in range(ORDER_MAX_RETRIES): try: await self._place_order(signal) return except PriceUnavailableException: logging.info( f"Price unavailable for {signal.coin}", color="red") except EntryCrossedException as err: logging.info( f"Price went too fast ({err.price}) for signal {signal}", color="yellow") except InsufficientQuantityException as err: logging.info( f"Allocated ${round(err.alloc_funds, 2)} for {err.alloc_q} {signal.coin} " f"but requires 
${round(err.est_funds, 2)} for {err.est_q} {signal.coin}", color="red") except Exception as err: logging.error( f"Failed to place order: {traceback.format_exc()} {err}" ) break # unknown error - don't block future signals if i < ORDER_MAX_RETRIES - 1: await asyncio.sleep(ORDER_RETRY_SLEEP) await self._unregister_order(signal) asyncio.ensure_future(_process(signal)) asyncio.ensure_future(_gatherer()) async def _place_order(self, signal: Signal): await self._subscribe_futures(signal.coin) for _ in range(10): if self.prices.get(signal.coin) is not None: break logging.info(f"Waiting for {signal.coin} price to be available") await asyncio.sleep(1) if self.prices.get(signal.coin) is None: raise PriceUnavailableException() symbol = f"{signal.coin}USDT" logging.info(f"Modifying leverage to {signal.leverage}x for {symbol}", color="green") await self.client.futures_change_leverage(symbol=symbol, leverage=signal.leverage) price = self.prices[signal.coin] signal.correct(price) alloc_funds = self.balance * signal.fraction quantity = alloc_funds / (price / signal.leverage) logging.info(f"Corrected signal: {signal}", color="cyan") qty = self._round_qty(symbol, quantity) est_funds = qty * signal.entry / signal.leverage if (est_funds / alloc_funds) > PRICE_SLIPPAGE: raise InsufficientQuantityException(quantity, alloc_funds, qty, est_funds) order_id = OrderID.wait() if signal.wait_entry else OrderID.market() params = { "symbol": symbol, "positionSide": "LONG" if signal.is_long else "SHORT", "side": "BUY" if signal.is_long else "SELL", "type": OrderType.MARKET, "quantity": qty, "newClientOrderId": order_id, } if (signal.is_long and price > signal.max_entry) or ( signal.is_short and price < signal.max_entry): raise EntryCrossedException(price) if signal.wait_entry: logging.info( f"Placing stop limit order for {signal.coin} (price @ {price}, entry @ {signal.entry})" ) params["type"] = OrderType.STOP params["stopPrice"] = self._round_price(symbol, signal.entry) params["price"] = 
self._round_price(symbol, signal.max_entry) else: logging.info( f"Placing market order for {signal.coin} (price @ {price}, entry @ {signal.entry}" ) async with self.olock: # Lock only for interacting with orders try: resp = await self.client.futures_create_order(**params) self.state["orders"][order_id] = { "id": resp["orderId"], "qty": float(resp["origQty"]), "sym": symbol, "side": params["side"], "ent": signal.entry if signal.wait_entry else price, "sl": signal.sl, "tgt": signal.targets, "fnd": alloc_funds, "lev": signal.leverage, "tag": signal.tag, "crt": int(time.time()), "t_ord": [], "t_q": [], } logging.info(f"Created order {order_id} for signal: {signal}, " f"params: {json.dumps(params)}, resp: {resp}") except Exception as err: logging.error( f"Failed to create order for signal {signal}: {err}, " f"params: {json.dumps(params)}") if isinstance(err, BinanceAPIException) and err.code == -2021: raise EntryCrossedException(price) async def _place_collection_orders(self, order_id): await self._place_sl_order(order_id) async with self.olock: odata = self.state["orders"][order_id] await self.results_handler( Trade.entry(odata["tag"], odata["sym"], odata["ent"], odata["qty"], odata["lev"], odata["side"])) if odata.get("t_ord"): logging.warning( f"TP order(s) already exist for parent {order_id}") return targets = odata["tgt"][:MAX_TARGETS] remaining = quantity = odata["qty"] for i, tgt in enumerate(targets): quantity *= 0.5 # NOTE: Don't close position (as it'll affect other orders) if i == len(targets) - 1: quantity = self._round_qty(odata["sym"], remaining) else: quantity = self._round_qty(odata["sym"], quantity) tgt_order_id = await self._create_target_order( order_id, odata["sym"], odata["side"], tgt, quantity) if tgt_order_id is None: continue odata["t_ord"].append(tgt_order_id) odata["t_q"].append(quantity) self.state["orders"][tgt_order_id] = { "parent": order_id, "filled": False, } remaining -= quantity async def _create_target_order(self, order_id, symbol, 
side, tgt_price, rounded_qty): tgt_order_id = OrderID.target() params = { "symbol": symbol, "type": OrderType.LIMIT, "timeInForce": "GTC", "positionSide": "LONG" if side == "BUY" else "SHORT", "side": "SELL" if side == "BUY" else "BUY", "newClientOrderId": tgt_order_id, "price": self._round_price(symbol, tgt_price), "quantity": rounded_qty, } try: resp = await self.client.futures_create_order(**params) logging.info( f"Created limit order {tgt_order_id} for parent {order_id}, " f"resp: {resp}, params: {json.dumps(params)}") return tgt_order_id except Exception as err: logging.error( f"Failed to create target order for parent {order_id}: {err}, " f"params: {json.dumps(params)}") async def _subscribe_futures_user(self): async def _handler(): while True: async with self.user_stream.message() as msg: try: data = msg event = msg['e'] if event == UserEventType.AccountUpdate: data = msg["a"] elif event == UserEventType.OrderTradeUpdate: data = msg["o"] elif event == UserEventType.AccountConfigUpdate: data = msg.get("ac", msg.get("ai")) logging.info(f"{event}: {data}") await self._handle_event(msg) except Exception as err: logging.exception( f"Failed to handle event {msg}: {err}") asyncio.ensure_future(_handler()) async def _handle_event(self, msg: dict): if msg["e"] == UserEventType.AccountUpdate: for info in msg["a"]["B"]: if info["a"] == "USDT": self.balance = float(info["cw"]) logging.info(f"Account balance: {self.balance} USDT", on="blue") elif msg["e"] == UserEventType.OrderTradeUpdate: info = msg["o"] order_id = info["c"] async with self.olock: o = self.state["orders"].get(order_id) if o is None: logging.warning( f"Received order {order_id} but missing in state") return if info["X"] == "FILLED": if OrderID.is_wait(order_id) or OrderID.is_market(order_id): entry = float(info["ap"]) logging.info( f"Placing TP/SL orders for fulfilled order {order_id} (entry: {entry})", color="green") async with self.olock: self.state["orders"][order_id]["ent"] = entry await 
self._place_collection_orders(order_id) elif OrderID.is_stop_loss(order_id): async with self.olock: logging.info( f"Order {order_id} hit stop loss. Removing TP orders...", color="red") sl = self.state["orders"].pop(order_id) parent = self.state["orders"].pop(sl["parent"]) for oid in parent["t_ord"]: self.state["orders"].pop( oid, None) # It might not exist await self._cancel_order(oid, parent["sym"]) await self.results_handler( Trade.target(parent["tag"], parent["sym"], parent["ent"], parent["qty"], parent["lev"], float(info["ap"]), float(info["q"]), is_long=parent["side"] == "BUY", is_sl=True)) elif OrderID.is_target(order_id): logging.info(f"TP order {order_id} hit.", color="green") await self._move_stop_loss(order_id) async def _move_stop_loss(self, tp_id: str): async with self.olock: tp = self.state["orders"][tp_id] tp["filled"] = True parent = self.state["orders"][tp["parent"]] targets = parent["t_ord"] if tp_id not in targets: if parent.get("s_ord") is None: logging.warning(f"SL doesn't exist for order {parent}") return logging.warning( f"Couldn't find TP order {tp_id} in parent {parent}, closing trade", color="red") await self.close_trades(parent["tag"], parent["sym"].replace("USDT", "")) return idx = targets.index(tp_id) await self.results_handler( Trade.target(parent["tag"], parent["sym"], parent["ent"], parent["qty"], parent["lev"], parent["tgt"][idx], parent["t_q"][idx], is_long=parent["side"] == "BUY")) if tp_id == targets[-1]: logging.info(f"All TP orders hit. 
Removing parent {parent}") parent = self.state["orders"].pop(tp["parent"]) for oid in parent["t_ord"]: self.state["orders"].pop(oid, None) # It might not exist self.state["orders"].pop(parent["s_ord"]) await self._cancel_order(parent["s_ord"], parent["sym"]) return else: new_price, quantity = parent["ent"], sum(parent["t_q"][(idx + 1):]) await self._place_sl_order(tp["parent"], new_price, quantity) async def _place_sl_order(self, parent_id: str, new_price=None, quantity=None): async with self.olock: odata = self.state["orders"][parent_id] symbol = odata["sym"] sl_order_id = OrderID.stop_loss() if odata.get("s_ord") is not None: logging.info( f"Moving SL order for {parent_id} to new price {new_price}" ) await self._cancel_order(odata["s_ord"], symbol) params = { "symbol": symbol, "positionSide": "LONG" if odata["side"] == "BUY" else "SHORT", "side": "SELL" if odata["side"] == "BUY" else "BUY", "type": OrderType.STOP_MARKET, "newClientOrderId": sl_order_id, "stopPrice": self._round_price( symbol, new_price if new_price is not None else odata["sl"]), "quantity": self._round_qty( symbol, (quantity if quantity is not None else odata["qty"])), } for _ in range(2): try: resp = await self.client.futures_create_order(**params) odata["s_ord"] = sl_order_id self.state["orders"][sl_order_id] = { "parent": parent_id, "filled": False, } logging.info( f"Created SL order {sl_order_id} for parent {parent_id}, " f"resp: {resp}, params: {json.dumps(params)}") break except Exception as err: logging.error( f"Failed to create SL order for parent {parent_id}: {err}, " f"params: {json.dumps(params)}") if isinstance( err, BinanceAPIException ) and err.code == -2021: # price is around SL now logging.info( f"Placing market order for parent {parent_id} " "after attempt to create SL order", color="yellow") params.pop("stopPrice") params["type"] = OrderType.MARKET async def _watch_orders(self): async def _watcher(): while True: try: open_symbols = await 
self._expire_outdated_orders_and_get_open_symbols( ) except Exception as err: logging.exception( f"Failed to expire outdated orders: {err}") async with self.slock: redundant = set( self.state["streams"]).difference(open_symbols) if redundant: logging.warning( f"Resetting price streams to {open_symbols}", color="yellow") self.state["streams"] = open_symbols await self._subscribe_futures(resub=redundant) for sym in redundant: self.prices.pop(f"{sym}USDT", None) now = time.time() async with self.olock: removed = [] for order_id, order in self.state["orders"].items(): if not OrderID.is_wait(order_id): continue if order.get("t_ord"): continue if now < (order["crt"] + WAIT_ORDER_EXPIRY): continue logging.warning( f"Wait order {order_id} has expired. Removing...", color="yellow") removed.append(order_id) await self._cancel_order(order_id, order["sym"]) for order_id in removed: self.state["orders"].pop(order_id) await asyncio.sleep(ORDER_WATCH_INTERVAL) asyncio.ensure_future(_watcher()) async def _expire_outdated_orders_and_get_open_symbols(self): open_symbols = [] open_orders, positions = {}, {} sl_orders = [] async with self.olock: resp = await self.client.futures_account() for pos in resp["positions"]: amount, side = float(pos["positionAmt"]), pos["positionSide"] if side == "BOTH" or amount == 0: continue positions[pos["symbol"] + side] = amount resp = await self.client.futures_get_open_orders() for order in resp: open_orders[order["clientOrderId"]] = order logging.info( f"Checking {len(open_orders)} orders for {len(positions)} positions: {positions}" ) for oid, order in open_orders.items(): odata = self.state["orders"].get(oid) if OrderID.is_market(oid) or OrderID.is_wait(oid): open_symbols.append(order["symbol"][:-4]) elif odata and (OrderID.is_target(oid) or OrderID.is_stop_loss(oid)): odata["filled"] = False removed = [] for oid, odata in list(self.state["orders"].items()): if (OrderID.is_target(oid) or OrderID.is_stop_loss(oid)): if 
self.state["orders"].get(odata["parent"]) is None: logging.warning( f"Order {oid} is now an orphan. Flagging for removal" ) removed.append(oid) if not (OrderID.is_wait(oid) or OrderID.is_market(oid)): continue if open_orders.get(oid) is not None: # only filled orders continue if (time.time() - odata["crt"]) < NEW_ORDER_TIMEOUT: # must be old continue if odata.get("s_ord") is None: # must have stop loss order continue side = "LONG" if odata["side"] == "BUY" else "SHORT" if not positions.get(odata["sym"] + side): logging.warning( f"Order {oid} missing in open positions. Flagging for removal" ) removed.append(oid) continue children = [odata["s_ord"]] + odata["t_ord"] for cid in children: if open_orders.get(cid) is not None: # ignore open orders continue is_filled = False try: cdata = await self.client.futures_get_order( symbol=odata["sym"], origClientOrderId=cid) is_filled = cdata["status"] == "FILLED" except Exception as err: logging.warning( f"Error fetching order {cid} for parent {odata}: {err}" ) if is_filled: self.state["orders"][cid] = { "parent": oid, "filled": True, } else: logging.info( f"Missing order {cid} detected for parent {oid}", color="yellow") self.state["orders"].pop(cid, None) order_hit = [] for i in range(len(children)): cdata = self.state["orders"].get(children[i]) order_hit.append(cdata["filled"] if cdata else False) if cdata is not None: continue if i == 0: sl_orders.append({"id": oid}) continue tgt_id = await self._create_target_order( oid, odata["sym"], odata["side"], odata["tgt"][i - 1], odata["t_q"][i - 1]) if tgt_id is not None: odata["t_ord"][i - 1] = tgt_id self.state["orders"][tgt_id] = { "parent": oid, "filled": False, } if order_hit[0] or all(order_hit[1:]): # All TPs or SL hit removed.append(oid) for oid in removed: logging.warning(f"Removing outdated order {oid}", color="yellow") parent = self.state["orders"].pop(oid, None) if not (OrderID.is_wait(oid) or OrderID.is_market(oid)): continue if parent is None: continue sym = 
parent["sym"] for cid in [parent["s_ord"]] + parent["t_ord"]: logging.warning(f"Removing outdated child {cid}", color="yellow") c = self.state["orders"].pop(cid, None) if c and not c["filled"]: await self._cancel_order(cid, sym) for o in sl_orders: oid = o["id"] for i, tid in enumerate(self.state["orders"][oid]["t_ord"]): if self.state["orders"][tid]["filled"]: o["tp"] = tid break for o in sl_orders: if o.get("tp"): await self._move_stop_loss(o["tp"]) else: await self._place_sl_order(o["id"]) return open_symbols async def _subscribe_futures(self, coin: str = None, resub=False): async with self.slock: num_streams = len(set(self.state["streams"])) resub = resub or (self.price_streamer is None) if coin is not None: coin = coin.upper() # We should have duplicates because it should be possible to long/short # on top of an existing long/short. self.state["streams"].append(coin) if num_streams != len(set(self.state["streams"])): resub = True if resub and self.price_streamer is not None: logging.info("Cancelling ongoing ws stream for resubscribing") self.price_streamer.cancel() symbols = set(self.state["streams"]) if not resub or not symbols: return async def _streamer(): subs = list(map(lambda s: s.lower() + "usdt@aggTrade", symbols)) logging.info( f"Spawning listener for {len(symbols)} symbol(s): {symbols}", color="magenta") async with self.manager.futures_multiplex_socket(subs) as stream: while True: msg = await stream.recv() if msg is None: logging.warning("Received 'null' in price stream", color="red") continue try: symbol = msg["stream"].split("@")[0][:-4].upper() self.prices[symbol.upper()] = float(msg["data"]["p"]) except Exception as err: logging.error( f"Failed to get price for {msg['stream']}: {err}") self.price_streamer = asyncio.ensure_future(_streamer()) async def _cancel_order(self, oid: str, symbol: str): try: resp = await self.client.futures_cancel_order( symbol=symbol, origClientOrderId=oid) logging.info(f"Cancelled order {oid}: {resp}") except Exception 
as err: logging.error(f"Failed to cancel order {oid}: {err}") async def _register_order_for_signal(self, signal: Signal): async with self.olock: self.sig_cache.expire() # Same provider can't give signal for same symbol within 20 seconds key = self._cache_key(signal) if self.sig_cache.get(key) is not None: return False self.sig_cache[key] = () return True async def _unregister_order(self, signal: Signal): async with self.olock: self.sig_cache.pop(self._cache_key(signal), None) def _cache_key(self, signal: Signal): return f"{signal.coin}_{signal.targets[0]}" # coin and first target for filter # MARK: Rouding for min quantity and min price for symbols def _round_price(self, symbol: str, price: float): info = self.symbols[symbol] for f in info["filters"]: if f["filterType"] == "PRICE_FILTER": return round( price, int(round(math.log(1 / float(f["tickSize"]), 10), 0))) return price def _round_qty(self, symbol: str, qty: float): info = self.symbols[symbol] for f in info["filters"]: if f["filterType"] == "LOT_SIZE": return round( qty, int(round(math.log(1 / float(f["minQty"]), 10), 0))) return qty
class WSNewBlocksWatcher(BaseWatcher):
    """
    Track new Ethereum blocks through a websocket ``newHeads`` subscription.

    For every block announced on the subscription the full block is fetched
    with ``self._w3.eth.getBlock``, stored in a small TTL cache keyed by block
    hash, and a ``NewBlocksWatcherEvent.NewBlocks`` event is triggered.
    """

    # Seconds to wait for a websocket message before probing the link with a ping.
    MESSAGE_TIMEOUT = 30.0
    # Seconds to wait for the pong answering that ping.
    PING_TIMEOUT = 10.0
    # Seconds to back off after an unexpected error in the fetch loop.
    ERROR_RETRY_INTERVAL = 30.0

    _nbw_logger: Optional[HummingbotLogger] = None

    def __init__(self, w3: Web3, websocket_url):
        """
        :param w3: Web3 instance used to fetch full blocks and the current block number
        :param websocket_url: websocket endpoint of the Ethereum node
        """
        super().__init__(w3)
        self._network_on = False
        # Monotonic request id attached to every JSON-RPC message sent.
        self._nonce: int = 0
        self._current_block_number: int = -1
        self._websocket_url = websocket_url
        # Value of "result" in the eth_subscribe response.
        self._node_address = None
        self._client: Optional[websockets.WebSocketClientProtocol] = None
        self._fetch_new_blocks_task: Optional[asyncio.Task] = None
        # Recently seen blocks keyed by block hash; entries expire after 120 seconds.
        self._block_cache = TTLCache(maxsize=10, ttl=120)

    @classmethod
    def logger(cls) -> HummingbotLogger:
        """Return the lazily created, class-wide logger."""
        if cls._nbw_logger is None:
            cls._nbw_logger = logging.getLogger(__name__)
        return cls._nbw_logger

    @property
    def block_number(self) -> int:
        """Number of the most recent block seen, or -1 if none has been seen yet."""
        return self._current_block_number

    @property
    def block_cache(self) -> Dict[HexBytes, AttributeDict]:
        """Return a plain-dict snapshot of the cached blocks keyed by block hash."""
        return {key: self._block_cache[key] for key in self._block_cache}

    async def start_network(self):
        """
        Connect, subscribe to ``newHeads`` and start the background fetch loop.

        If a fetch task already exists the watcher is stopped instead of being
        started again.
        """
        if self._fetch_new_blocks_task is not None:
            await self.stop_network()
        else:
            try:
                self._current_block_number = await self.call_async(getattr, self._w3.eth, "blockNumber")
            except asyncio.CancelledError:
                raise
            except Exception:
                self.logger().network("Error fetching newest Ethereum block number.",
                                      app_warning_msg="Error fetching newest Ethereum block number. "
                                                      "Check Ethereum node connection",
                                      exc_info=True)
            await self.connect()
            await self.subscribe(["newHeads"])
            self._fetch_new_blocks_task = safe_ensure_future(self.fetch_new_blocks_loop())
            self._network_on = True

    async def stop_network(self):
        """Cancel the fetch loop (if running) and close the websocket."""
        if self._fetch_new_blocks_task is None:
            return
        # Cancel before closing the socket: otherwise the loop can observe the
        # closed connection first and start reconnecting while we shut down.
        self._fetch_new_blocks_task.cancel()
        self._fetch_new_blocks_task = None
        self._network_on = False
        await self.disconnect()

    async def connect(self):
        """
        Open the websocket connection.

        :return: the connected client, or None when connecting failed (the
                 failure is logged, not raised)
        """
        try:
            self._client = await websockets.connect(uri=self._websocket_url)
            return self._client
        except Exception as e:
            self.logger().network(f"ERROR in connection: {e}")

    async def disconnect(self):
        """Close the websocket connection; failures are logged, not raised."""
        if self._client is None:
            # Never connected: nothing to close.
            return
        try:
            await self._client.close()
        except Exception as e:
            self.logger().network(f"ERROR in disconnection: {e}")

    async def _send(self, emit_data) -> int:
        """
        Send a JSON message, tagging it with the next request id.

        :param emit_data: message dict; its "id" key is set in place
        :return: the request id that was used
        """
        self._nonce += 1
        emit_data["id"] = self._nonce
        await self._client.send(ujson.dumps(emit_data))
        return self._nonce

    async def subscribe(self, params) -> bool:
        """
        Issue an ``eth_subscribe`` request and read the next message as its reply.

        :param params: subscription parameters, e.g. ``["newHeads"]``
        :return: True when the reply carries the id of the request just sent
        """
        emit_data = {
            "method": "eth_subscribe",
            "params": params
        }
        nonce = await self._send(emit_data)
        raw_message = await self._client.recv()
        if raw_message is not None:
            resp = ujson.loads(raw_message)
            if resp.get("id", None) == nonce:
                self._node_address = resp.get("result")
                return True
        return False

    async def _messages(self) -> AsyncIterable[Any]:
        """
        Yield raw websocket messages until the connection is lost.

        When no message arrives within MESSAGE_TIMEOUT a ping is sent; if the
        pong does not arrive within PING_TIMEOUT the generator ends. The socket
        is always closed when the generator finishes.
        """
        try:
            while True:
                try:
                    raw_msg_str: str = await asyncio.wait_for(self._client.recv(), self.MESSAGE_TIMEOUT)
                    yield raw_msg_str
                except asyncio.TimeoutError:
                    # Quiet link: check that the peer is still alive. A ping
                    # timeout propagates to the outer handler below.
                    pong_waiter = await self._client.ping()
                    await asyncio.wait_for(pong_waiter, timeout=self.PING_TIMEOUT)
        except asyncio.TimeoutError:
            self.logger().warning("WebSocket ping timed out. Going to reconnect...")
            return
        except ConnectionClosed:
            return
        finally:
            await self.disconnect()

    async def fetch_new_blocks_loop(self):
        """
        Consume subscription messages forever, caching and announcing new blocks.

        Whenever the message stream ends (the socket was closed), the watcher
        reconnects and re-subscribes before reading again.
        """
        while True:
            try:
                async for raw_message in self._messages():
                    if raw_message is None:
                        continue
                    message_json = ujson.loads(raw_message)
                    if message_json.get("method", None) != "eth_subscription":
                        continue
                    subscription_result_params = message_json.get("params", None)
                    incoming_block = subscription_result_params.get("result", None) \
                        if subscription_result_params is not None else None
                    if incoming_block is None:
                        continue
                    new_block: AttributeDict = await self.call_async(self._w3.eth.getBlock,
                                                                     incoming_block.get("hash"), True)
                    self._current_block_number = new_block.get("number")
                    self._block_cache[new_block.get("hash")] = new_block
                    self.trigger_event(NewBlocksWatcherEvent.NewBlocks, [new_block])
                # _messages() only finishes after closing the socket, so a new
                # connection and subscription are needed before reading again.
                # If connect() fails, subscribe() raises on the stale client and
                # the generic handler below applies the retry back-off.
                await self.connect()
                if not await self.subscribe(["newHeads"]):
                    self.logger().warning("Failed to re-subscribe to newHeads after reconnecting.")
            except asyncio.TimeoutError:
                self.logger().network("Timed out fetching new block.", exc_info=True,
                                      app_warning_msg="Timed out fetching new block. "
                                                      "Check wallet network connection")
            except asyncio.CancelledError:
                raise
            except BlockNotFound:
                pass
            except Exception as e:
                self.logger().network(f"Error fetching new block: {e}", exc_info=True,
                                      app_warning_msg="Error fetching new block. "
                                                      "Check wallet network connection")
                await asyncio.sleep(self.ERROR_RETRY_INTERVAL)

    async def get_timestamp_for_block(self, block_hash: HexBytes, max_tries: Optional[int] = 10) -> int:
        """
        Return the timestamp of a block, waiting for it to appear in the cache.

        :param block_hash: hash of the block to look up
        :param max_tries: number of 0.5 second waits before giving up
        :return: the block's "timestamp" field
        :raises ValueError: if the block is still not cached after max_tries waits
        """
        counter = 0
        # A single get() per poll: a separate membership test could succeed and
        # the entry then expire before it is read.
        block: AttributeDict = self._block_cache.get(block_hash)
        while block is None:
            if counter == max_tries:
                raise ValueError(f"Block hash {block_hash.hex()} does not exist.")
            counter += 1
            await asyncio.sleep(0.5)
            block = self._block_cache.get(block_hash)
        return block.get("timestamp")
import time
from pprint import pprint

from cachetools import TTLCache

# Demo of TTLCache expiry: fill the cache, read one entry after two seconds,
# add another entry, and print the cache again once the 3-second TTL of the
# first batch has passed.
cache = TTLCache(maxsize=10, ttl=3)
# Keys apple1..apple10, inserted in order, fill the cache to its maxsize.
cache.update({f"apple{n}": 'top dog' for n in range(1, 11)})
pprint(cache)

print("Sleeping for 2 Seconds")
time.sleep(2)
print(
    "Fetching the data again to check if get will preserver log fro more time")
print(cache.get('apple2'))
pprint(cache)

# The cache is at maxsize here, so this insert has to evict an existing entry.
cache["mango1"] = "mango 2"
print("Sleeping for 2 Seconds")
time.sleep(2)
print("After 2 Seconds")
pprint(cache)
# Cog with GitHub commands: linking an account through OAuth, creating gists,
# searching repositories and rendering github-readme-stats cards.
# NOTE: the class and the command callbacks are documented with comments
# rather than new docstrings because discord.py shows docstrings to users as
# help text.
class Github(commands.Cog):
    # Themes accepted by the stats / top-languages commands.
    _THEMES = (
        "default",
        "dark",
        "radical",
        "merko",
        "gruvbox",
        "tokyonight",
        "onedark",
        "cobalt",
        "synthwave",
        "highcontrast",
        "dracula",
    )
    _INVALID_THEME_MESSAGE = (
        "Not a valid theme. List of all valid themes:- " + ", ".join(_THEMES))
    _STATS_API = "https://github-readme-stats.codestackr.vercel.app/api"

    def __init__(self, bot: commands.Bot):
        self.bot = bot
        # Splits the gist input on code fences (with optional language tag),
        # leaving alternating file names and file contents.
        self.files_regex = re.compile(r"\s{0,}```\w{0,}\s{0,}")
        # OAuth tokens by Discord user id, kept for 10 minutes.
        self.token_cache = TTLCache(maxsize=1000, ttl=600)

    @property
    def session(self):
        # Reuses the HTTP session held by the bot's HTTP client
        # (name-mangled private attribute).
        return self.bot.http._HTTPClient__session

    async def cog_check(self, ctx: commands.Context):
        # Runs before every command of this cog: resolves the invoking user's
        # GitHub token (cache first, then database) and stores it on
        # ctx.gh_token. Raises GithubNotLinkedError when no token exists,
        # except for the link command, which must work without one.
        token = self.token_cache.get(ctx.author.id)
        if not token:
            user = await UserModel.get_or_none(id=ctx.author.id)
            # The user row may not exist yet (e.g. first use of the link command).
            token = user.github_oauth_token if user is not None else None
            if token is None and ctx.command != self.link_github:
                raise GithubNotLinkedError()
            if token is not None:
                self.token_cache[ctx.author.id] = token
        ctx.gh_token = token
        return True

    @commands.command(name="linkgithub", aliases=["lngithub"])
    async def link_github(self, ctx: commands.Context):
        # DMs the user an OAuth authorization link. The "state" parameter is a
        # JWT carrying the Discord user id and an expiry two minutes from now.
        expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=120)
        url = "https://github.com/login/oauth/authorize?" + urlencode({
            "client_id": github_oauth_config.client_id,
            "scope": "gist",
            "redirect_uri": "https://tech-struck.vercel.app/oauth/github",
            "state": jwt.encode({
                "id": ctx.author.id,
                "expiry": str(expiry)
            }, config.secret),
        })
        await ctx.author.send(embed=Embed(
            title="Connect Github",
            description=
            f"Click [this]({url}) to link your github account. This link invalidates in 2 minutes",
        ))

    @commands.command(name="creategist", aliases=["crgist"])
    async def create_gist(self, ctx: commands.Context, *, inp):
        """
        Create gists from within discord

        Example:
        filename.py
        ```
        # Codeblock with contents of filename.py
        ```

        filename2.txt
        ```
        Codeblock containing filename2.txt's contents
        ```
        """
        # Splitting on the fences yields name, content, name, content, ...
        # followed by whatever trails the last fence, which is dropped.
        files_and_names = self.files_regex.split(inp)[:-1]
        # Dict comprehension to create the files 'object'
        files = {
            name: {
                "content": content + "\n"
            }
            for name, content in zip(files_and_names[0::2],
                                     files_and_names[1::2])
        }
        req = await self.github_request(ctx, "POST", "/gists",
                                        json={"files": files})
        res = await req.json()
        # TODO: Make this more verbose to the user and log errors
        await ctx.send(res.get("html_url", "Something went wrong."))

    @commands.command(name="githubsearch", aliases=["ghsearch", "ghse"])
    async def github_search(self, ctx: commands.Context, *, term: str):
        # TODO: Docs
        # Shows the first five repository search results for `term`.
        req = await self.github_request(ctx, "GET", "/search/repositories",
                                        dict(q=term, per_page=5))
        data = await req.json()
        # Error payloads carry no "items" key; treat them like an empty result
        # instead of raising KeyError.
        items = data.get("items")
        if not items:
            return await ctx.send(embed=Embed(
                title=f"Searched for {term}",
                color=Color.red(),
                description="No results found",
            ))
        em = Embed(
            title=f"Searched for {term}",
            color=Color.green(),
            description="\n\n".join([
                "[{0[owner][login]}/{0[name]}]({0[html_url]})\n{0[stargazers_count]:,} :star:\u2800{0[forks_count]} \u2387\u2800\n{1}"
                .format(result, self.repo_desc_format(result))
                for result in items
            ]),
        )
        await ctx.send(embed=em)

    @commands.command(name="githubstats", aliases=["ghstats", "ghst"])
    async def github_stats(self, ctx, username="******", theme="radical"):
        # Sends the github-readme-stats card of `username` rendered as a PNG.
        theme = theme.lower()
        if theme not in self._THEMES:
            return await ctx.send(self._INVALID_THEME_MESSAGE)
        # Query values are url-encoded so unusual characters in the username
        # cannot alter the request.
        url = self._STATS_API + "?" + urlencode({
            "username": username,
            "show_icons": "true",
            "hide_border": "true",
            "theme": theme,
        })
        file = await self.get_file_from_svg_url(url, exclude=[b"A++", b"A+"])
        await ctx.send(file=discord.File(file, filename="stats.png"))

    @commands.command(name="githublanguages",
                      aliases=["ghlangs", "ghtoplangs"])
    async def github_top_languages(self, ctx, username="******", theme="radical"):
        # Sends the top-languages card of `username` rendered as a PNG.
        theme = theme.lower()
        if theme not in self._THEMES:
            return await ctx.send(self._INVALID_THEME_MESSAGE)
        url = self._STATS_API + "/top-langs/?" + urlencode({
            "username": username,
            "theme": theme
        })
        file = await self.get_file_from_svg_url(url)
        await ctx.send(file=discord.File(file, filename="langs.png"))

    async def get_file_from_svg_url(self, url, exclude=(), fmt="PNG"):
        """
        Download an SVG and render it to an in-memory image file.

        :param url: address of the SVG document
        :param exclude: byte strings removed from the SVG source before rendering
        :param fmt: output format handed to ``renderPM.drawToString``
        :return: BytesIO holding the rendered image
        """
        # The context manager releases the response once the body is read.
        async with self.session.get(url) as response:
            res = await response.read()
        for fragment in exclude:
            # removes everything that needs to be excluded (eg. the uncentered A+)
            res = res.replace(fragment, b"")
        drawing = svg2rlg(BytesIO(res))
        return BytesIO(renderPM.drawToString(drawing, fmt=fmt))

    @staticmethod
    def repo_desc_format(result):
        """Return the repository description, cut to 100 characters plus "..." when longer."""
        description = result["description"]
        if not description:
            return ""
        return description if len(description) < 100 else (description[:100] + "...")

    def github_request(
        self,
        ctx: commands.Context,
        req_type: str,
        endpoint: str,
        params: dict = None,
        json: dict = None,
    ):
        """
        Start an authenticated request against the GitHub REST API.

        :param ctx: context whose ``gh_token`` (set by ``cog_check``) is sent as bearer token
        :param req_type: HTTP method, e.g. "GET" or "POST"
        :param endpoint: API path starting with "/", appended to https://api.github.com
        :param params: optional query parameters
        :param json: optional JSON request body
        :return: whatever ``self.session.request`` returns; callers await it to get the response
        """
        return self.session.request(
            req_type,
            f"https://api.github.com{endpoint}",
            params=params,
            json=json,
            headers={"Authorization": f"Bearer {ctx.gh_token}"},
        )
class MySQLStatementSamples(object): """ Collects statement samples and execution plans. """ executor = ThreadPoolExecutor() def __init__(self, check, config, connection_args): self._check = check self._version_processed = False self._connection_args = connection_args # checkpoint at zero so we pull the whole history table on the first run self._checkpoint = 0 self._log = get_check_logger() self._last_check_run = 0 self._db = None self._tags = None self._tags_str = None self._service = "mysql" self._collection_loop_future = None self._cancel_event = threading.Event() self._rate_limiter = ConstantRateLimiter(1) self._config = config self._db_hostname = resolve_db_host(self._config.host) self._enabled = is_affirmative( self._config.statement_samples_config.get('enabled', False)) self._run_sync = is_affirmative( self._config.statement_samples_config.get('run_sync', False)) self._collections_per_second = self._config.statement_samples_config.get( 'collections_per_second', -1) self._events_statements_row_limit = self._config.statement_samples_config.get( 'events_statements_row_limit', 5000) self._explain_procedure = self._config.statement_samples_config.get( 'explain_procedure', 'explain_statement') self._fully_qualified_explain_procedure = self._config.statement_samples_config.get( 'fully_qualified_explain_procedure', 'datadog.explain_statement') self._events_statements_temp_table = self._config.statement_samples_config.get( 'events_statements_temp_table_name', 'datadog.temp_events') self._events_statements_enable_procedure = self._config.statement_samples_config.get( 'events_statements_enable_procedure', 'datadog.enable_events_statements_consumers') self._preferred_events_statements_tables = EVENTS_STATEMENTS_PREFERRED_TABLES self._has_window_functions = False events_statements_table = self._config.statement_samples_config.get( 'events_statements_table', None) if events_statements_table: if events_statements_table in DEFAULT_EVENTS_STATEMENTS_COLLECTIONS_PER_SECOND: 
self._log.debug( "Configured preferred events_statements_table: %s", events_statements_table) self._preferred_events_statements_tables = [ events_statements_table ] else: self._log.warning( "Invalid events_statements_table: %s. Must be one of %s. Falling back to trying all tables.", events_statements_table, ', '.join(DEFAULT_EVENTS_STATEMENTS_COLLECTIONS_PER_SECOND. keys()), ) self._explain_strategies = { 'PROCEDURE': self._run_explain_procedure, 'FQ_PROCEDURE': self._run_fully_qualified_explain_procedure, 'STATEMENT': self._run_explain, } self._preferred_explain_strategies = [ 'PROCEDURE', 'FQ_PROCEDURE', 'STATEMENT' ] self._init_caches() self._statement_samples_client = _new_statement_samples_client() def _init_caches(self): self._collection_strategy_cache = TTLCache( maxsize=self._config.statement_samples_config.get( 'collection_strategy_cache_maxsize', 1000), ttl=self._config.statement_samples_config.get( 'collection_strategy_cache_ttl', 300), ) # explained_statements_cache: limit how often we try to re-explain the same query self._explained_statements_cache = TTLCache( maxsize=self._config.statement_samples_config.get( 'explained_statements_cache_maxsize', 5000), ttl=60 * 60 / self._config.statement_samples_config.get( 'explained_statements_per_hour_per_query', 60), ) # seen_samples_cache: limit the ingestion rate per (query_signature, plan_signature) self._seen_samples_cache = TTLCache( # assuming ~100 bytes per entry (query & plan signature, key hash, 4 pointers (ordered dict), expiry time) # total size: 10k * 100 = 1 Mb maxsize=self._config.statement_samples_config.get( 'seen_samples_cache_maxsize', 10000), ttl=60 * 60 / self._config.statement_samples_config.get( 'samples_per_hour_per_query', 15), ) def run_sampler(self, tags): """ start the sampler thread if not already running & update tag metadata :param tags: :return: """ if not self._enabled: self._log.debug("Statement sampler not enabled") return self._tags = tags self._tags_str = ','.join(tags) for t 
in self._tags: if t.startswith('service:'): self._service = t[len('service:'):] if not self._version_processed and self._check.version: self._has_window_functions = self._check.version.version_compatible( (8, 0, 0)) if self._check.version.flavor == "MariaDB" or not self._check.version.version_compatible( (5, 7, 0)): self._global_status_table = "information_schema.global_status" else: self._global_status_table = "performance_schema.global_status" self._version_processed = True self._last_check_run = time.time() if self._run_sync or is_affirmative( os.environ.get('DBM_STATEMENT_SAMPLER_RUN_SYNC', "false")): self._log.debug("Running statement sampler synchronously") self._collect_statement_samples() elif self._collection_loop_future is None or not self._collection_loop_future.running( ): self._collection_loop_future = MySQLStatementSamples.executor.submit( self.collection_loop) else: self._log.debug( "Statement sampler collection loop already running") def cancel(self): """ Cancels the collection loop thread if it's running. Returns immediately, leaving the thread to stop & clean up on its own time. 
""" self._cancel_event.set() def _get_db_connection(self): """ lazy reconnect db pymysql connections are not thread safe so we can't reuse the same connection from the main check :return: """ if not self._db: self._db = pymysql.connect(**self._connection_args) return self._db def _close_db_conn(self): if self._db: try: self._db.close() except Exception: self._log.debug("Failed to close db connection", exc_info=1) finally: self._db = None def collection_loop(self): try: self._log.info("Starting statement sampler collection loop") while True: if self._cancel_event.isSet(): self._log.info("Collection loop cancelled") self._check.count( "dd.mysql.statement_samples.collection_loop_cancel", 1, tags=self._tags) break if time.time( ) - self._last_check_run > self._config.min_collection_interval * 2: self._log.info( "Stopping statement sampler collection loop due to check inactivity" ) self._check.count( "dd.mysql.statement_samples.collection_loop_inactive_stop", 1, tags=self._tags) break self._collect_statement_samples() except pymysql.err.DatabaseError as e: self._log.warning( "Statement sampler database error: %s", e, exc_info=self._log.getEffectiveLevel() == logging.DEBUG) self._check.count( "dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:collection-loop-database-error-{}".format(type(e))], ) except Exception as e: self._log.exception("Statement sampler collection loop crash") self._check.count( "dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:collection-loop-crash-{}".format(type(e))], ) finally: self._log.info("Shutting down statement sampler collection loop") self._close_db_conn() def _cursor_run(self, cursor, query, params=None, obfuscated_params=None): """ Run and log the query. If provided, obfuscated params are logged in place of the regular params. 
""" self._log.debug("Running query [%s] %s", query, obfuscated_params if obfuscated_params else params) cursor.execute(query, params) def _get_new_events_statements(self, events_statements_table, row_limit): # Select the most recent events with a bias towards events which have higher wait times start = time.time() drop_temp_table_query = "DROP TEMPORARY TABLE IF EXISTS {}".format( self._events_statements_temp_table) params = (self._checkpoint, row_limit) with closing(self._get_db_connection().cursor( pymysql.cursors.DictCursor)) as cursor: # silence expected warnings to avoid spam cursor.execute('SET @@SESSION.sql_notes = 0') try: self._cursor_run(cursor, drop_temp_table_query) self._cursor_run( cursor, CREATE_TEMP_TABLE.format( temp_table=self._events_statements_temp_table, statements_table="performance_schema." + events_statements_table, ), params, ) except pymysql.err.DatabaseError as e: self._check.count( "dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:create-temp-table-{}".format(type(e))], ) raise if self._has_window_functions: sub_select = SUB_SELECT_EVENTS_WINDOW else: self._cursor_run(cursor, "set @row_num = 0") self._cursor_run(cursor, "set @current_digest = ''") sub_select = SUB_SELECT_EVENTS_NUMBERED self._cursor_run( cursor, "set @startup_time_s = {}".format( STARTUP_TIME_SUBQUERY.format( global_status_table=self._global_status_table)), ) self._cursor_run( cursor, EVENTS_STATEMENTS_QUERY.format( statements_numbered=sub_select.format( statements_table=self._events_statements_temp_table)), params, ) rows = cursor.fetchall() self._cursor_run(cursor, drop_temp_table_query) tags = self._tags + [ "events_statements_table:{}".format(events_statements_table) ] self._check.histogram("dd.mysql.get_new_events_statements.time", (time.time() - start) * 1000, tags=tags) self._check.histogram("dd.mysql.get_new_events_statements.rows", len(rows), tags=tags) self._log.debug("Read %s rows from %s", len(rows), events_statements_table) return rows def 
_filter_valid_statement_rows(self, rows): num_sent = 0 num_truncated = 0 for row in rows: if not row or not all(row): self._log.debug( 'Row was unexpectedly truncated or the events_statements table is not enabled' ) continue sql_text = row['sql_text'] if not sql_text: continue # The SQL_TEXT column will store 1024 chars by default. Plans cannot be captured on truncated # queries, so the `performance_schema_max_sql_text_length` variable must be raised. if sql_text[-3:] == '...': num_truncated += 1 continue yield row # only save the checkpoint for rows that we have successfully processed # else rows that we ignore can push the checkpoint forward causing us to miss some on the next run if row['timer_start'] > self._checkpoint: self._checkpoint = row['timer_start'] num_sent += 1 if num_truncated > 0: self._log.warning( 'Unable to collect %d/%d statement samples due to truncated SQL text. Consider raising ' '`performance_schema_max_sql_text_length` to capture these queries.', num_truncated, num_truncated + num_sent, ) self._check.count("dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:truncated-sql-text"]) def _collect_plan_for_statement(self, row): # Plans have several important signatures to tag events with: # - `plan_signature` - hash computed from the normalized JSON plan to group identical plan trees # - `resource_hash` - hash computed off the raw sql text to match apm resources # - `query_signature` - hash computed from the digest text to match query metrics try: obfuscated_statement = datadog_agent.obfuscate_sql(row['sql_text']) obfuscated_digest_text = datadog_agent.obfuscate_sql( row['digest_text']) except Exception: # do not log the raw sql_text to avoid leaking sensitive data into logs. 
digest_text is safe as parameters # are obfuscated by the database self._log.debug("Failed to obfuscate statement: %s", row['digest_text']) self._check.count("dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:sql-obfuscate"]) return None apm_resource_hash = compute_sql_signature(obfuscated_statement) query_signature = compute_sql_signature(obfuscated_digest_text) query_cache_key = (row['current_schema'], query_signature) if query_cache_key in self._explained_statements_cache: return None self._explained_statements_cache[query_cache_key] = True plan = None with closing(self._get_db_connection().cursor()) as cursor: try: plan = self._explain_statement(cursor, row['sql_text'], row['current_schema'], obfuscated_statement) except Exception as e: self._check.count("dd.mysql.statement_samples.error", 1, tags=self._tags + ["error:explain-{}".format(type(e))]) self._log.exception("Failed to explain statement: %s", obfuscated_statement) normalized_plan, obfuscated_plan, plan_signature, plan_cost = None, None, None, None if plan: normalized_plan = datadog_agent.obfuscate_sql_exec_plan( plan, normalize=True) if plan else None obfuscated_plan = datadog_agent.obfuscate_sql_exec_plan(plan) plan_signature = compute_exec_plan_signature(normalized_plan) plan_cost = self._parse_execution_plan_cost(plan) query_plan_cache_key = (query_cache_key, plan_signature) if query_plan_cache_key not in self._seen_samples_cache: self._seen_samples_cache[query_plan_cache_key] = True return { "timestamp": row["timer_end_time_s"] * 1000, "host": self._db_hostname, "service": self._service, "ddsource": "mysql", "ddtags": self._tags_str, "duration": row['timer_wait_ns'], "network": { "client": { "ip": row.get('processlist_host', None), } }, "db": { "instance": row['current_schema'], "plan": { "definition": obfuscated_plan, "cost": plan_cost, "signature": plan_signature }, "query_signature": query_signature, "resource_hash": apm_resource_hash, "statement": obfuscated_statement, }, 'mysql': 
{ k: v for k, v in row.items() if k not in EVENTS_STATEMENTS_SAMPLE_EXCLUDE_KEYS }, }

    def _collect_plans_for_statements(self, rows):
        """
        Lazily yields one event for every row that produces one.

        Each row is handed to `_collect_plan_for_statement`; falsy results are skipped. An exception raised
        while handling a row is logged at debug level (with traceback) and the loop moves on to the next row.

        :param rows: iterable of statement rows
        :return: generator of events
        """
        for row in rows:
            try:
                event = self._collect_plan_for_statement(row)
                if event:
                    yield event
            except Exception:
                # best effort: a failing row must not abort the remaining rows
                self._log.debug("Failed to collect plan for statement", exc_info=1)

    def _get_enabled_performance_schema_consumers(self):
        """
        Returns the set of enabled performance schema consumers
        I.e. (events_statements_current, events_statements_history)

        Runs ENABLED_STATEMENTS_CONSUMERS_QUERY on a fresh cursor and collects the first column of every row.

        :return: set of the first-column values returned by the query
        """
        with closing(self._get_db_connection().cursor()) as cursor:
            self._cursor_run(cursor, ENABLED_STATEMENTS_CONSUMERS_QUERY)
            return set([r[0] for r in cursor.fetchall()])

    def _enable_events_statements_consumers(self):
        """
        Enable events statements consumers by calling the stored procedure named by
        `self._events_statements_enable_procedure`.

        A `pymysql.err.DatabaseError` is logged at debug level and swallowed, so this call is best effort.

        :return: None
        """
        try:
            with closing(self._get_db_connection().cursor()) as cursor:
                self._cursor_run(
                    cursor,
                    'CALL {}()'.format(self._events_statements_enable_procedure))
        except pymysql.err.DatabaseError as e:
            self._log.debug(
                "failed to enable events_statements consumers using procedure=%s: %s",
                self._events_statements_enable_procedure,
                e,
            )

    def _get_sample_collection_strategy(self):
        """
        Decides on the plan collection strategy:
        - which events_statement_history-* table are we using
        - how long should the rate and time limits be

        A previously chosen strategy is served from `self._collection_strategy_cache` when present. Only
        successful strategies are cached; both failure paths return `(None, None)` without caching.

        :return: (table, rate_limit), or (None, None) when no usable table was found
        """
        cached_strategy = self._collection_strategy_cache.get(
            "plan_collection_strategy")
        if cached_strategy:
            self._log.debug("Using cached plan_collection_strategy: %s",
                            cached_strategy)
            return cached_strategy

        enabled_consumers = self._get_enabled_performance_schema_consumers()
        # fewer than three enabled consumers: try to enable them, then re-read the enabled set
        if len(enabled_consumers) < 3:
            self._enable_events_statements_consumers()
            enabled_consumers = self._get_enabled_performance_schema_consumers(
            )

        if not enabled_consumers:
            self._log.warning(
                "Cannot collect statement samples as there are no enabled performance_schema.events_statements_* "
                "consumers. Enable performance_schema and at least one events_statements consumer in order to collect "
                "statement samples.")
            self._check.count(
                "dd.mysql.statement_samples.error",
                1,
                tags=self._tags + ["error:no-enabled-events-statements-consumers"],
            )
            return None, None
        self._log.debug(
            "Found enabled performance_schema statements consumers: %s",
            enabled_consumers)

        # pick the first preferred table that is both enabled and currently returns at least one row
        events_statements_table = None
        for table in self._preferred_events_statements_tables:
            if table not in enabled_consumers:
                continue
            rows = self._get_new_events_statements(table, 1)
            if not rows:
                self._log.debug(
                    "No statements found in %s table. checking next one.",
                    table)
                continue
            events_statements_table = table
            break
        if not events_statements_table:
            self._log.warning(
                "Cannot collect statement samples as all enabled events_statements_consumers %s are empty.",
                enabled_consumers,
            )
            return None, None

        # a negative configured rate selects the per-table default
        rate_limit = self._collections_per_second
        if rate_limit < 0:
            rate_limit = DEFAULT_EVENTS_STATEMENTS_COLLECTIONS_PER_SECOND[
                events_statements_table]

        # cache only successful strategies
        # should be short enough that we'll reflect updates relatively quickly
        # i.e., an aurora replica becomes a master (or vice versa).
        strategy = (events_statements_table, rate_limit)
        self._log.debug(
            "Chose plan collection strategy: events_statements_table=%s, collections_per_second=%s",
            events_statements_table,
            rate_limit,
        )
        self._collection_strategy_cache["plan_collection_strategy"] = strategy
        return strategy

    def _collect_statement_samples(self):
        """
        Runs one collection pass: waits on the rate limiter, resolves the collection strategy, reads new
        statement rows from the chosen table, builds events for them and submits the events.

        Returns early (without emitting metrics) when no events_statements table is available. Otherwise
        emits the submit-failure count, the elapsed time in milliseconds, the submitted count and the current
        sizes of the three caches.

        :return: None
        """
        self._rate_limiter.sleep()
        events_statements_table, rate_limit = self._get_sample_collection_strategy(
        )
        if not events_statements_table:
            return
        # the strategy may carry a different rate than the current limiter; rebuild the limiter if so
        if self._rate_limiter.rate_limit_s != rate_limit:
            self._rate_limiter = ConstantRateLimiter(rate_limit)
        start_time = time.time()
        tags = self._tags + [
            "events_statements_table:{}".format(events_statements_table)
        ]
        rows = self._get_new_events_statements(
            events_statements_table, self._events_statements_row_limit)
        rows = self._filter_valid_statement_rows(rows)
        # `events` is a generator; presumably submit_events consumes it, so plan collection happens during
        # submission -- TODO confirm against the client implementation
        events = self._collect_plans_for_statements(rows)
        submitted_count, failed_count = self._statement_samples_client.submit_events(
            events)
        self._check.count("dd.mysql.statement_samples.error",
                          failed_count,
                          tags=self._tags + ["error:submit-events"])
        self._check.histogram("dd.mysql.collect_statement_samples.time",
                              (time.time() - start_time) * 1000,
                              tags=tags)
        self._check.count(
            "dd.mysql.collect_statement_samples.events_submitted.count",
            submitted_count,
            tags=tags)
        self._check.gauge(
            "dd.mysql.collect_statement_samples.seen_samples_cache.len",
            len(self._seen_samples_cache),
            tags=tags)
        self._check.gauge(
            "dd.mysql.collect_statement_samples.explained_statements_cache.len",
            len(self._explained_statements_cache),
            tags=tags,
        )
        self._check.gauge(
            "dd.mysql.collect_statement_samples.collection_strategy_cache.len",
            len(self._collection_strategy_cache),
            tags=tags,
        )

    def _explain_statement(self, cursor, statement, schema, obfuscated_statement):
        """
        Tries the available methods used to explain a statement for the given schema. If a non-retryable
        error occurs (such as a permissions error), then statements executed under the schema will be
        disallowed in future attempts.

        :param cursor: database cursor used for the USE and explain queries
        :param statement: raw statement text handed to the explain strategies
        :param schema: default schema of the statement; falsy when there is none
        :param obfuscated_statement: obfuscated statement text, used for the explainability check and logging
        :return: the first truthy plan returned by a strategy, or None when the statement is skipped or
                 every strategy fails
        :raises pymysql.err.DatabaseError: re-raised when the error's args do not have exactly two elements
        """
        start_time = time.time()
        strategy_cache_key = 'explain_strategy:%s' % schema
        # sentinel stored in the strategy cache to mark a schema whose plans cannot be collected
        explain_strategy_error = 'ERROR'
        tags = self._tags + ["schema:{}".format(schema)]

        # NOTE(review): this debug line logs the raw `statement`, not the obfuscated one -- confirm intended
        self._log.debug('explaining statement. schema=%s, statement="%s"',
                        schema, statement)

        if not self._can_explain(obfuscated_statement):
            self._log.debug('Skipping statement which cannot be explained: %s',
                            obfuscated_statement)
            return None

        if self._collection_strategy_cache.get(
                strategy_cache_key) == explain_strategy_error:
            self._log.debug(
                'Skipping statement due to cached collection failure: %s',
                obfuscated_statement)
            return None

        try:
            # If there was a default schema when this query was run, then switch to it before trying to collect
            # the execution plan. This is necessary when the statement uses non-fully qualified tables
            # e.g. `select * from mytable` instead of `select * from myschema.mytable`
            if schema:
                self._cursor_run(cursor, 'USE `{}`'.format(schema))
        except pymysql.err.DatabaseError as e:
            if len(e.args) != 2:
                raise
            # only non-retryable errors poison the schema in the cache; any other error just skips this statement
            if e.args[0] in PYMYSQL_NON_RETRYABLE_ERRORS:
                self._collection_strategy_cache[
                    strategy_cache_key] = explain_strategy_error
            self._check.count("dd.mysql.statement_samples.error",
                              1,
                              tags=tags + ["error:explain-use-schema-{}".format(type(e))])
            self._log.debug(
                'Failed to collect execution plan because schema could not be accessed. error=%s, schema=%s, '
                'statement="%s"',
                e.args,
                schema,
                obfuscated_statement,
            )
            return None

        # Use a cached strategy for the schema, if any, or try each strategy to collect plans
        strategies = list(self._preferred_explain_strategies)
        cached = self._collection_strategy_cache.get(strategy_cache_key)
        if cached:
            # move the strategy that last worked for this schema to the front
            strategies.remove(cached)
            strategies.insert(0, cached)

        for strategy in strategies:
            try:
                if not schema and strategy == "PROCEDURE":
                    self._log.debug(
                        'skipping PROCEDURE strategy as there is no default schema for this statement="%s"',
                        obfuscated_statement,
                    )
                    continue
                plan = self._explain_strategies[strategy](cursor, statement,
                                                          obfuscated_statement)
                if plan:
                    self._collection_strategy_cache[
                        strategy_cache_key] = strategy
                    self._log.debug(
                        'Successfully collected execution plan. strategy=%s, schema=%s, statement="%s"',
                        strategy,
                        schema,
                        obfuscated_statement,
                    )
                    self._check.histogram(
                        "dd.mysql.run_explain.time",
                        (time.time() - start_time) * 1000,
                        tags=self._tags + ["strategy:{}".format(strategy)],
                    )
                    return plan
            except pymysql.err.DatabaseError as e:
                if len(e.args) != 2:
                    raise
                # we don't cache failed plan collection failures for specific queries because some queries in a schema
                # can fail while others succeed. The failed collection will be cached for the specific query
                # so we won't try to explain it again for the cache duration there.
                self._check.count(
                    "dd.mysql.statement_samples.error",
                    1,
                    tags=tags + ["error:explain-attempt-{}-{}".format(strategy, type(e))],
                )
                self._log.debug(
                    'Failed to collect execution plan. error=%s, strategy=%s, schema=%s, statement="%s"',
                    e.args,
                    strategy,
                    schema,
                    obfuscated_statement,
                )
                continue
        return None

    def _run_explain(self, cursor, statement, obfuscated_statement):
        """
        Run the explain using the EXPLAIN statement

        :param cursor: database cursor to execute on
        :param statement: raw statement text, interpolated directly into the EXPLAIN query
        :param obfuscated_statement: obfuscated statement text, used only for the debug log
        :return: first column of the first result row
        """
        self._log.debug("running query [EXPLAIN FORMAT=json %s]",
                        obfuscated_statement)
        # NOTE(review): unlike the procedure variants this calls cursor.execute directly, not _cursor_run
        cursor.execute('EXPLAIN FORMAT=json {}'.format(statement))
        return cursor.fetchone()[0]

    def _run_explain_procedure(self, cursor, statement, obfuscated_statement):
        """
        Run the explain by calling the stored procedure if available.

        The procedure name comes from `self._explain_procedure`; the statement is passed as the `%s` parameter.

        :return: first column of the first result row
        """
        self._cursor_run(cursor, 'CALL {}(%s)'.format(self._explain_procedure),
                         statement, obfuscated_statement)
        return cursor.fetchone()[0]

    def _run_fully_qualified_explain_procedure(self, cursor, statement,
                                               obfuscated_statement):
        """
        Run the explain by calling the fully qualified stored procedure if available.

        The procedure name comes from `self._fully_qualified_explain_procedure`; the statement is passed as
        the `%s` parameter.

        :return: first column of the first result row
        """
        self._cursor_run(
            cursor,
            'CALL {}(%s)'.format(self._fully_qualified_explain_procedure),
            statement, obfuscated_statement)
        return cursor.fetchone()[0]

    @staticmethod
    def _can_explain(obfuscated_statement):
        """
        Returns whether the statement's leading keyword is explainable.

        The keyword is the text before the first space, lower-cased, and it must be a member of
        VALID_EXPLAIN_STATEMENTS.

        :param obfuscated_statement: obfuscated statement text
        :return: True when the leading keyword is in VALID_EXPLAIN_STATEMENTS
        """
        return obfuscated_statement.split(
            ' ', 1)[0].lower() in VALID_EXPLAIN_STATEMENTS

    @staticmethod
    def _parse_execution_plan_cost(execution_plan):
        """
        Parses the total cost from the execution plan, if set. If not set, returns cost of 0.

        Reads `query_block.cost_info.query_cost` from the JSON document.

        :param execution_plan: JSON text of the execution plan
        :return: the cost as a float; 0.0 when the key path is missing or the value is falsy
        """
        cost = json.loads(execution_plan).get('query_block', {}).get(
            'cost_info', {}).get('query_cost', 0.0)
        return float(cost or 0.0)
class DnsOverHttpsResolver(object):
    """Answer DNS requests by querying the ``/resolve`` JSON endpoint of dns.google.com.

    JSON responses that contain an ``Answer`` section are kept in a ``TTLCache``
    keyed by ``"<qname>_<qtype>"`` and replayed for later identical questions.

    Args:
        loop: event loop handed to ``aiohttp.ClientSession``.
        semaphore: asyncio semaphore bounding concurrent HTTP fetches.
        public_ip: client IP reported (as a /24) for ordinary names.
        proxy_ip: client IP reported (as a /24) for names matching ``domain_set``.
        google_ip: IP used in the URL and by ``GoogleDirectConnector`` when no
            SOCKS proxy is configured.
        domain_set: iterable of domain suffixes that should use ``proxy_ip``;
            ``None`` is treated as empty.
        socks_proxy: proxy URL; when set, requests go through ``ProxyConnector``.
        cache_size: ``maxsize`` of the response cache.
        cache_ttl: ``ttl`` of the response cache.
    """

    def __init__(self,
                 loop=None,
                 semaphore=None,
                 public_ip=None,
                 proxy_ip=None,
                 google_ip=None,
                 domain_set=None,
                 socks_proxy=None,
                 cache_size=None,
                 cache_ttl=None):
        self.loop = loop
        self.semaphore = semaphore
        self.public_ip = public_ip
        self.proxy_ip = proxy_ip
        self.domain_set = domain_set
        self.google_ip = google_ip
        self.socks_proxy = socks_proxy
        self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl)
        # Through a proxy the hostname is used; otherwise the request goes to the
        # configured IP and the Host header below carries the hostname.
        if self.socks_proxy:
            self.base_url = 'https://{}/resolve?'.format('dns.google.com')
        else:
            self.base_url = 'https://{}/resolve?'.format(google_ip)
        self.headers = {'Host': 'dns.google.com'}
        self.transport = None

    def match_client_ip(self, query_name):
        """Return the client IP to advertise for ``query_name``.

        ``proxy_ip`` is chosen when the name (minus its last character,
        presumably the trailing root dot) ends with any suffix in
        ``domain_set``; otherwise ``public_ip``.
        """
        if self.public_ip == self.proxy_ip:
            return self.public_ip
        # str.endswith accepts a tuple of suffixes; an empty tuple never matches,
        # which also covers the default domain_set=None.
        suffixes = tuple(self.domain_set or ())
        if str(query_name)[:-1].endswith(suffixes):
            return self.proxy_ip
        return self.public_ip

    # NOTE(review): if `retry` is a synchronous decorator it only wraps the call
    # that creates the coroutine, so errors raised while awaiting would not be
    # retried -- confirm which `retry` implementation is imported.
    @retry(exceptions=aiohttp.client_exceptions.ClientConnectionError, tries=3)
    async def http_fetch(self, url):
        """Fetch ``url`` and return the raw response body.

        The semaphore is held for the whole request. A new client session is
        created per call, using the SOCKS connector when ``socks_proxy`` is set
        and ``GoogleDirectConnector`` otherwise.
        """
        # `with await lock` was removed from asyncio in Python 3.9 and aiohttp
        # sessions only support `async with`; both are entered asynchronously.
        async with self.semaphore:
            if self.socks_proxy:
                connector = ProxyConnector(remote_resolve=True)
                request_class = ProxyClientRequest
            else:
                connector = GoogleDirectConnector(self.google_ip)
                request_class = aiohttp.client_reqrep.ClientRequest
            async with aiohttp.ClientSession(
                    loop=self.loop,
                    connector=connector,
                    request_class=request_class) as session:
                async with session.get(url,
                                       proxy=self.socks_proxy,
                                       headers=self.headers) as resp:
                    return await resp.read()

    @staticmethod
    def _build_rr(record):
        """Build an ``RR`` from one JSON record (``name``/``type``/``TTL``/``data``)."""
        # NOTE(review): the rdata class is looked up by name in module globals using
        # a type taken from the remote response; a type with no matching global
        # raises KeyError.
        rdata_class = globals()[QTYPE[record['type']]]
        return RR(rname=record['name'],
                  rtype=record['type'],
                  ttl=record['TTL'],
                  rdata=rdata_class(record['data']))

    @staticmethod
    def build_answer_from_json(request, json_item):
        """Pack a reply to ``request`` from a decoded JSON response.

        The reply's rcode is the JSON ``Status``. Records come from ``Answer``
        when present, otherwise from ``Authority`` when present.

        Returns:
            The packed reply.
        """
        ans = request.reply()
        ans.header.rcode = json_item['Status']
        if 'Answer' in json_item:
            for answer in json_item['Answer']:
                ans.add_answer(DnsOverHttpsResolver._build_rr(answer))
        elif 'Authority' in json_item:
            for auth in json_item['Authority']:
                ans.add_auth(DnsOverHttpsResolver._build_rr(auth))
        return ans.pack()

    @staticmethod
    def build_serv_fail(request):
        """Pack a reply to ``request`` whose rcode is 2 (SERVFAIL)."""
        ans = request.reply()
        ans.header.rcode = 2
        return ans.pack()

    async def query_and_answer(self, transport, request, client):
        """Resolve ``request`` and send the packed reply to ``client``."""
        packet_resp = await self.query_request(request)
        self.send_response(transport, packet_resp, client)

    async def query_request(self, request):
        """Return the packed reply for ``request``.

        Serves from the cache when possible; otherwise queries the HTTP
        endpoint, caching the JSON only when it has an ``Answer`` section. An
        empty HTTP body produces a SERVFAIL reply.
        """
        logging.debug('Request name: {}, Request type:{}.'.format(
            request.q.qname, QTYPE[request.q.qtype]))
        cache_key = str(request.q.qname) + '_' + str(request.q.qtype)
        cached_item = self.cache.get(cache_key)
        if cached_item:
            logging.debug('Cached:{}'.format(cache_key))
            return self.build_answer_from_json(request, cached_item)

        logging.debug('Fetch:{}'.format(request.q.qname))
        client_ip = self.match_client_ip(request.q.qname) + '/24'
        url = self.base_url + urlencode({
            'name': request.q.qname,
            'type': request.q.qtype,
            'edns_client_subnet': client_ip
        })
        logging.debug('Querying URL:{}.'.format(url))
        http_resp = await self.http_fetch(url)
        if not http_resp:
            logging.debug('Return Serv Fail!')
            return self.build_serv_fail(request)

        json_resp = json.loads(http_resp)
        if 'Answer' in json_resp:
            self.cache[cache_key] = json_resp
        return self.build_answer_from_json(request, json_resp)

    @staticmethod
    def send_response(transport, packet_resp, client):
        """Send ``packet_resp`` to ``client`` over ``transport``.

        Transports with ``sendto`` get the packet as-is. When the call raises
        ``AttributeError`` (presumably a stream transport without ``sendto``),
        the packet is prefixed with its length as a 2-byte big-endian integer,
        written, and the transport is closed.
        """
        try:
            transport.sendto(packet_resp, client)
        except AttributeError:
            packet_resp = struct.pack(">H", len(packet_resp)) + packet_resp
            transport.write(packet_resp)
            transport.close()