class SearchIndex:
    """Claim search and URL resolution backed by an Elasticsearch index.

    Two clients are kept: ``sync_client`` (long timeout) is used for index
    creation, bulk writes, update-by-query and refreshes, while
    ``search_client`` (short timeout) is used for ``search`` and ``mget``.
    Results are memoised in four LRU caches held on the instance.
    """

    def __init__(self, index_prefix: str, search_timeout=3.0):
        """Store configuration and create empty caches; no connection is opened here.

        :param index_prefix: prepended to ``'claims'`` to form the index name.
        :param search_timeout: timeout handed to the search client in ``start``.
        """
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
        self.search_cache = LRUCache(2 ** 17)
        self.resolution_cache = LRUCache(2 ** 17)

    async def start(self):
        """Create both clients, wait for the cluster and create the index.

        Does nothing (returns ``None``) when a sync client already exists.

        :return: the ``acknowledged`` flag of the index-creation response,
            ``False`` when the response does not carry one.
        """
        if self.sync_client:
            return
        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
        # Retry once per second until the health call stops raising ConnectionError.
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)
        # NOTE(review): ignore=400 presumably suppresses the "index already exists"
        # error so that restarting against an existing index works - confirm.
        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        return res.get('acknowledged', False)

    def stop(self):
        """Detach both clients and close them in the background.

        Clients that were never created (``stop`` before ``start``) are skipped
        instead of raising ``AttributeError`` on ``None.close()``.

        :return: a future that completes once every client has been closed.
        """
        clients = [client for client in (self.sync_client, self.search_client) if client is not None]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        """Return the (un-awaited) request that deletes this instance's index."""
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        """Turn ``(op, doc)`` pairs into bulk actions, logging progress.

        For ``op == 'delete'`` the ``doc`` value is used as the document id;
        every other op is converted with ``extract_doc``.
        """
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        """Bulk-apply every operation from ``claim_producer`` and refresh the index.

        Failed items are logged and skipped (``raise_on_error=False``).
        """
        actions = self._consume_claim_producer(claim_producer)
        async for ok, item in async_streaming_bulk(self.sync_client, actions, raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        """Build the update-by-query body that marks documents as censored.

        :param censor_type: value written to ``censor_type``; only documents
            matching ``censor_type < censor_type`` are selected.
        :param blockdict: mapping of ``bytes`` to ``bytes``; both sides are
            byte-reversed and hex-encoded before use.
        :param channels: match on ``channel_id`` instead of ``claim_id``.
        :return: the query dict with a painless ``script`` entry added.
        """
        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def _update_censorship(self, censor_type, blockdict, channels=False):
        """Run one censorship update-by-query and refresh the index afterwards."""
        await self.sync_client.update_by_query(
            self.index, body=self.update_filter_query(censor_type, blockdict, channels), slices=4)
        await self.sync_client.indices.refresh(self.index)

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        """Write filter/block markers into the index, then drop cached results.

        Channel mappings are applied twice: once against ``claim_id`` and once
        against ``channel_id`` (see ``update_filter_query``). Empty mappings
        are skipped. The search, claim and resolution caches are always
        cleared; the short id cache is left untouched.
        """
        if filtered_streams:
            await self._update_censorship(Censor.SEARCH, filtered_streams)
        if filtered_channels:
            await self._update_censorship(Censor.SEARCH, filtered_channels)
            await self._update_censorship(Censor.SEARCH, filtered_channels, True)
        if blocked_streams:
            await self._update_censorship(Censor.RESOLVE, blocked_streams)
        if blocked_channels:
            await self._update_censorship(Censor.RESOLVE, blocked_channels)
            await self._update_censorship(Censor.RESOLVE, blocked_channels, True)
        self.search_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def session_query(self, query_name, kwargs):
        """Answer a ``resolve`` or search request as a base64 ``Outputs`` payload.

        :param query_name: ``'resolve'`` routes to ``resolve``; anything else
            is treated as a search.
        :param kwargs: for ``'resolve'`` an iterable of urls (unpacked
            positionally); otherwise a dict of search constraints.
        :return: the value produced by ``Outputs.to_base64``. Search results
            are cached under ``str(kwargs)``; resolve results are not cached
            here.
        """
        offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                # Re-check after acquiring the lock with the same test as above,
                # so a falsy-but-present cached result is not recomputed.
                if cache_item.result is not None:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                if kwargs.get('no_totals'):
                    response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                else:
                    response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    # Something was censored: query again restricted to uncensored rows.
                    response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        """Resolve every url to a claim row or an error object.

        :return: ``(results, censored, censor)`` where ``results`` holds the
            raw per-url outcome and ``censored`` replaces rows rejected by the
            censor with ``ResolveCensoredError``.
        """
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        # just heat the cache
        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_hash'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    def _get_from_cache_or_error(self, url: str,
                                 resolution: 'Union[LookupError, StreamResolution, ChannelResolution]'):
        """Return the cached claim for ``resolution``, or an error when it is not cached.

        A ``LookupError`` is passed through unchanged; any other resolution
        that misses the cache is turned into ``resolution.lookup_error(url)``.
        """
        cached = self.claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))

    async def get_many(self, *claim_ids):
        """Return an iterator over the cached rows for ``claim_ids``, skipping misses."""
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        """Fetch the claims that are not cached yet with one ``mget`` and cache them."""
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        """Expand a short claim id to the full claim id, caching the answer.

        :param name: claim name to match.
        :param short_id: short claim id to match.
        :param channel_id: when given, restrict to validly signed claims of
            that channel ordered by ``^channel_join``; otherwise order by
            ``^creation_height``.
        :return: the full claim id, or ``None`` when the search does not
            return exactly one row.
        """
        # The fields are joined with a delimiter so that adjacent values cannot
        # run together: plain concatenation gave ("ab", "c") and ("a", "bc")
        # the same cache key.
        key = '#'.join((channel_id or '', name, short_id))
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) == 1:
                result = result[0]['claim_id']
                self.short_id_cache[key] = result
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        """Run one search and return ``(rows, 0, total)``.

        A ``channel`` constraint is first resolved to a ``channel_id``; when it
        does not resolve to a string the search short-circuits to
        ``([], 0, 0)``, which is also returned on ``NotFoundError``.
        """
        if 'channel' in kwargs:
            kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
            if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
                return [], 0, 0
        try:
            result = (await self.search_client.search(
                expand_query(**kwargs), index=self.index,
                track_total_hits=False if kwargs.get('no_totals') else 10_000
            ))['hits']
        except NotFoundError:
            return [], 0, 0
        return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def resolve_url(self, raw_url):
        """Resolve ``raw_url`` through the resolution cache."""
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        """Resolve ``raw_url`` without consulting the resolution cache.

        :return: the ``ValueError`` raised while parsing, a ``LookupError``
            when the channel cannot be found, otherwise a ``StreamResolution``
            (urls with a stream part) or a ``ChannelResolution``.
        """
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        stream = LookupError(f'Could not find claim at "{raw_url}".')

        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        # Keep the LookupError placeholder when the stream lookup returns nothing.
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: 'URL'):
        """Find the claim id of the channel part of ``url``.

        :return: ``None`` when the url has no channel, the claim id when
            found, or a ``LookupError`` instance (returned, not raised) when
            the channel cannot be found.
        """
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: 'URL', channel_id: str = None):
        """Find the claim id of the stream part of ``url``.

        :param channel_id: already-resolved channel id for urls that carry a
            channel part.
        :return: the claim id, or ``None`` when the url has no stream, when
            its channel was not resolved, or when nothing matches.
        """
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        """Load the claims referenced by ``txo_rows``.

        Collects reposted claims, channels and censoring channels of the dict
        rows (non-dict entries are ignored), fetches them, then fetches the
        channels of those fetched claims as well.

        :return: list of the fetched rows.
        """
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(map(parse_claim_id, filter(None, (row['censoring_channel_hash'] for row in txo_rows))))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            # Second hop: the channels of the claims that were just loaded.
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
class SearchIndex:
    """Claim search backed by a versioned Elasticsearch index.

    ``sync_client`` (long timeout) performs index management and writes;
    ``search_client`` (short timeout) performs ``search`` and ``mget``.
    ``VERSION`` is stored in an index template when the index is created and
    compared against on later starts.
    """

    VERSION = 1

    def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200):
        """Store configuration and create empty caches; no connection is opened here.

        :param index_prefix: prepended to ``'claims'`` to form the index name.
        :param search_timeout: timeout handed to the search client in ``start``.
        :param elastic_host: host both clients connect to.
        :param elastic_port: port both clients connect to.
        """
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # wont hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.search_cache = LRUCache(2 ** 17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        """Return the version stored in the index template, ``0`` when there is none."""
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        """Store ``version`` in a template named after the index.

        NOTE(review): the ``'ignored'`` index pattern suggests the template is
        only used as a place to keep the version number - confirm.
        """
        await self.sync_client.indices.put_template(
            self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400
        )

    async def start(self) -> bool:
        """Connect, wait for the cluster, and create or validate the index.

        :return: ``False`` when already started; otherwise the
            ``acknowledged`` flag of the index-creation response (``False``
            when the response carries none).
        :raises IndexVersionMismatch: when index creation was not acknowledged
            and the stored version differs from ``VERSION``.
        """
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout)
        # Retry once per second until the health call stops raising ConnectionError.
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            # Creation was acknowledged: stamp the index with the current version.
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        await self.sync_client.indices.refresh(self.index)
        return acked

    def stop(self):
        """Detach both clients and close them in the background.

        Clients that were never created (``stop`` before ``start``) are skipped
        instead of raising ``AttributeError`` on ``None.close()``.

        :return: a future that completes once every client has been closed.
        """
        clients = [client for client in (self.sync_client, self.search_client) if client is not None]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        """Return the (un-awaited) request that deletes this instance's index."""
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        """Turn an async stream of ``(op, doc)`` pairs into bulk actions.

        For ``op == 'delete'`` the ``doc`` value is used as the document id;
        every other op becomes an upsert restricted to the keys in
        ``ALL_FIELDS``. Progress is logged every 100 items.
        """
        count = 0
        async for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield {
                    'doc': {key: value for key, value in doc.items() if key in ALL_FIELDS},
                    '_id': doc['claim_id'],
                    '_index': self.index,
                    '_op_type': 'update',
                    'doc_as_upsert': True
                }
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        if count:
            self.logger.info("Indexing done for %d claims.", count)
        else:
            self.logger.debug("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        """Bulk-apply every operation from ``claim_producer`` and refresh the index.

        Failed items are logged and skipped (``raise_on_error=False``).
        """
        actions = self._consume_claim_producer(claim_producer)
        async for ok, item in async_streaming_bulk(self.sync_client, actions, raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.debug("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        """Build the update-by-query body that marks documents as censored.

        :param censor_type: value written to ``censor_type``; only documents
            matching ``censor_type < censor_type`` are selected.
        :param blockdict: mapping of blocked ``bytes`` to blocker ``bytes``;
            both sides are hex-encoded before use.
        :param channels: match on ``channel_id`` instead of ``claim_id``.
        :return: the query dict with a painless ``script`` entry added.
        """
        blockdict = {blocked.hex(): blocker.hex() for blocked, blocker in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; "
                      f"ctx._source.censoring_channel_id=params[ctx._source.{key}];",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def update_trending_score(self, params):
        """Apply amount changes to each claim's ``trending_score`` via a painless script.

        :param params: mapping of claim id to an iterable of change objects
            exposing ``height``, ``prev_amount`` and ``new_amount``; amounts
            are divided by ``1E8`` before being handed to the script.
            Nothing is sent when the mapping is empty.
        """
        if not params:
            return

        # The script contains `//` line comments, so its line breaks are significant.
        update_trending_score_script = """
        double softenLBC(double lbc) { return (Math.pow(lbc, 1.0 / 3.0)); }

        double logsumexp(double x, double y)
        {
            double top;
            if(x > y)
                top = x;
            else
                top = y;
            double result = top + Math.log(Math.exp(x-top) + Math.exp(y-top));
            return(result);
        }

        double logdiffexp(double big, double small)
        {
            return big + Math.log(1.0 - Math.exp(small - big));
        }

        double squash(double x)
        {
            if(x < 0.0)
                return -Math.log(1.0 - x);
            else
                return Math.log(x + 1.0);
        }

        double unsquash(double x)
        {
            if(x < 0.0)
                return 1.0 - Math.exp(-x);
            else
                return Math.exp(x) - 1.0;
        }

        double log_to_squash(double x)
        {
            return logsumexp(x, 0.0);
        }

        double squash_to_log(double x)
        {
            //assert x > 0.0;
            return logdiffexp(x, 0.0);
        }

        double squashed_add(double x, double y)
        {
            // squash(unsquash(x) + unsquash(y)) but avoiding overflow.
            // Cases where the signs are the same
            if (x < 0.0 && y < 0.0)
                return -logsumexp(-x, logdiffexp(-y, 0.0));
            if (x >= 0.0 && y >= 0.0)
                return logsumexp(x, logdiffexp(y, 0.0));
            // Where the signs differ
            if (x >= 0.0 && y < 0.0)
                if (Math.abs(x) >= Math.abs(y))
                    return logsumexp(0.0, logdiffexp(x, -y));
                else
                    return -logsumexp(0.0, logdiffexp(-y, x));
            if (x < 0.0 && y >= 0.0)
            {
                // Addition is commutative, hooray for new math
                return squashed_add(y, x);
            }
            return 0.0;
        }

        double squashed_multiply(double x, double y)
        {
            // squash(unsquash(x)*unsquash(y)) but avoiding overflow.
            int sign;
            if(x*y >= 0.0)
                sign = 1;
            else
                sign = -1;
            return sign*logsumexp(squash_to_log(Math.abs(x))
                            + squash_to_log(Math.abs(y)), 0.0);
        }

        // Squashed inflated units
        double inflateUnits(int height) {
            double timescale = 576.0; // Half life of 400 = e-folding time of a day
                                      // by coincidence, so may as well go with it
            return log_to_squash(height / timescale);
        }

        double spikePower(double newAmount) {
            if (newAmount < 50.0) {
                return(0.5);
            } else if (newAmount < 85.0) {
                return(newAmount / 100.0);
            } else {
                return(0.85);
            }
        }

        double spikeMass(double oldAmount, double newAmount) {
            double softenedChange = softenLBC(Math.abs(newAmount - oldAmount));
            double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount));
            double power = spikePower(newAmount);
            if (oldAmount > newAmount) {
                -1.0 * Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power)
            } else {
                Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power)
            }
        }
        for (i in params.src.changes) {
            double units = inflateUnits(i.height);
            if (ctx._source.trending_score == null) {
                ctx._source.trending_score = 0.0;
            }
            double bigSpike = squashed_multiply(units, squash(spikeMass(i.prev_amount, i.new_amount)));
            ctx._source.trending_score = squashed_add(ctx._source.trending_score, bigSpike);
        }
        """
        start = time.perf_counter()

        def producer():
            # One scripted update action per claim, carrying all of its changes.
            for claim_id, claim_updates in params.items():
                yield {
                    '_id': claim_id,
                    '_index': self.index,
                    '_op_type': 'update',
                    'script': {
                        'lang': 'painless',
                        'source': update_trending_score_script,
                        'params': {'src': {
                            'changes': [
                                {
                                    'height': p.height,
                                    'prev_amount': p.prev_amount / 1E8,
                                    'new_amount': p.new_amount / 1E8,
                                } for p in claim_updates
                            ]
                        }}
                    },
                }

        async for ok, item in async_streaming_bulk(self.sync_client, producer(), raise_on_error=False):
            if not ok:
                self.logger.warning("updating trending failed for an item: %s", item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("updated trending scores in %ims", int((time.perf_counter() - start) * 1000))

    async def _update_censorship(self, censor_type, blockdict, channels=False):
        """Run one censorship update-by-query and refresh the index afterwards."""
        await self.sync_client.update_by_query(
            self.index, body=self.update_filter_query(censor_type, blockdict, channels), slices=4)
        await self.sync_client.indices.refresh(self.index)

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        """Write filter/block markers into the index, then clear the caches.

        Channel mappings are applied twice: once against ``claim_id`` and once
        against ``channel_id`` (see ``update_filter_query``). Empty mappings
        are skipped; the caches are cleared regardless.
        """
        if filtered_streams:
            await self._update_censorship(Censor.SEARCH, filtered_streams)
        if filtered_channels:
            await self._update_censorship(Censor.SEARCH, filtered_channels)
            await self._update_censorship(Censor.SEARCH, filtered_channels, True)
        if blocked_streams:
            await self._update_censorship(Censor.RESOLVE, blocked_streams)
        if blocked_channels:
            await self._update_censorship(Censor.RESOLVE, blocked_channels)
            await self._update_censorship(Censor.RESOLVE, blocked_channels, True)
        self.clear_caches()

    def clear_caches(self):
        """Empty the search and claim caches."""
        self.search_cache.clear()
        self.claim_cache.clear()

    @staticmethod
    def _to_resolve_result(row):
        """Build a ``ResolveResult`` from one expanded index row."""
        return ResolveResult(
            name=row['claim_name'],
            normalized_name=row['normalized_name'],
            claim_hash=row['claim_hash'],
            tx_num=row['tx_num'],
            position=row['tx_nout'],
            tx_hash=row['tx_hash'],
            height=row['height'],
            amount=row['amount'],
            short_url=row['short_url'],
            is_controlling=row['is_controlling'],
            canonical_url=row['canonical_url'],
            creation_height=row['creation_height'],
            activation_height=row['activation_height'],
            expiration_height=row['expiration_height'],
            effective_amount=row['effective_amount'],
            support_amount=row['support_amount'],
            last_takeover_height=row['last_take_over_height'],
            claims_in_channel=row['claims_in_channel'],
            channel_hash=row['channel_hash'],
            reposted_claim_hash=row['reposted_claim_hash'],
            reposted=row['reposted'],
            signature_valid=row['signature_valid']
        )

    async def cached_search(self, kwargs):
        """Run a censored search and return it as a base64 ``Outputs`` payload.

        :param kwargs: dict of search constraints; ``str(kwargs)`` is the
            cache key.
        :return: the value produced by ``Outputs.to_base64``, served from the
            search cache when present.
        """
        total_referenced = []
        cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
        if cache_item.result is not None:
            return cache_item.result
        async with cache_item.lock:
            # Re-check after acquiring the lock with the same test as above,
            # so a falsy-but-present cached result is not recomputed.
            if cache_item.result is not None:
                return cache_item.result
            censor = Censor(Censor.SEARCH)
            if kwargs.get('no_totals'):
                response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
            else:
                response, offset, total = await self.search(**kwargs)
            censor.apply(response)
            total_referenced.extend(response)

            if censor.censored:
                # Something was censored: query again restricted to uncensored rows.
                response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                total_referenced.extend(response)
            response = [self._to_resolve_result(row) for row in response]
            extra = [self._to_resolve_result(row) for row in await self._get_referenced_rows(total_referenced)]
            result = Outputs.to_base64(response, extra, offset, total, censor)
            cache_item.result = result
            return result

    async def get_many(self, *claim_ids):
        """Return an iterator over the cached rows for ``claim_ids``, skipping misses."""
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        """Fetch the claims that are not cached yet with one ``mget`` and cache them."""
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(index=self.index, body={"ids": missing})
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def search(self, **kwargs):
        """Delegate to ``search_ahead``; return ``([], 0, 0)`` on ``NotFoundError``."""
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0

    async def search_ahead(self, **kwargs):
        """Fetch up to 1000 hits, reorder them, and return the requested page.

        The keys ``limit_claims_per_channel``, ``remove_duplicates``,
        ``limit`` (page size, default 10) and ``offset`` are consumed here;
        everything else is passed to ``expand_query`` with ``limit`` forced
        to 1000.

        :return: ``(rows, 0, number_of_reordered_hits)``.
        """
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        # The cached ordering depends on duplicate removal and, when a
        # per-channel limit is active, on the page size, so both belong in the
        # key; otherwise differently-shaped requests would share one entry.
        page_key = page_size if per_channel_per_page > 0 else ''
        cache_item = ResultCacheItem.from_cache(
            f"ahead{per_channel_per_page}:{page_key}:{bool(remove_duplicates)}{kwargs}", self.search_cache
        )
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                # `is not None` so that a cached empty hit list counts as a hit.
                if cache_item.result is not None:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque((await self.search_client.search(
                        query, index=self.index, track_total_hits=False,
                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                    ))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        """Keep one hit per underlying claim: the one with the lowest ``creation_height``.

        A hit's underlying claim is its ``reposted_claim_id`` when set,
        otherwise its own ``_id``. On equal heights the earlier hit wins.
        The relative order of the surviving hits is preserved.
        """
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: deque, page_size: int, per_channel_per_page: int):
        """Reorder hits so each channel appears at most ``per_channel_per_page`` times per page.

        Hits over a channel's quota are deferred and retried at the start of
        the next page. Hits without a channel are always accepted.

        :param search_hits: hits in relevance order; consumed from the left.
        :return: list of ``(claim_id, channel_id)`` tuples.
        """
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                # A page was completed: quotas start over.
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            # Give every deferred hit one chance on the new page.
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    # NOTE(review): the page boundary is only checked on this
                    # branch, not after a channel-less hit - confirm intended.
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
        return reordered_hits

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        """Load the claims referenced by ``txo_rows``.

        Collects reposted claims, channels and censoring channels of the dict
        rows (non-dict entries are ignored), fetches them, then fetches the
        channels of those fetched claims as well.

        :return: list of the fetched rows.
        """
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            # Second hop: the channels of the claims that were just loaded.
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
class SearchIndex: VERSION = 1 def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200): self.search_timeout = search_timeout self.sync_timeout = 600 # wont hit that 99% of the time, but can hit on a fresh import self.search_client: Optional[AsyncElasticsearch] = None self.sync_client: Optional[AsyncElasticsearch] = None self.index = index_prefix + 'claims' self.logger = class_logger(__name__, self.__class__.__name__) self.claim_cache = LRUCache(2 ** 15) self.short_id_cache = LRUCache(2 ** 17) self.search_cache = LRUCache(2 ** 17) self.resolution_cache = LRUCache(2 ** 17) self._elastic_host = elastic_host self._elastic_port = elastic_port async def get_index_version(self) -> int: try: template = await self.sync_client.indices.get_template(self.index) return template[self.index]['version'] except NotFoundError: return 0 async def set_index_version(self, version): await self.sync_client.indices.put_template( self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400 ) async def start(self) -> bool: if self.sync_client: return False hosts = [{'host': self._elastic_host, 'port': self._elastic_port}] self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout) self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout) while True: try: await self.sync_client.cluster.health(wait_for_status='yellow') break except ConnectionError: self.logger.warning("Failed to connect to Elasticsearch. 
Waiting for it!") await asyncio.sleep(1) res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400) acked = res.get('acknowledged', False) if acked: await self.set_index_version(self.VERSION) return acked index_version = await self.get_index_version() if index_version != self.VERSION: self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION) raise IndexVersionMismatch(index_version, self.VERSION) return acked def stop(self): clients = [self.sync_client, self.search_client] self.sync_client, self.search_client = None, None return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients))) def delete_index(self): return self.sync_client.indices.delete(self.index, ignore_unavailable=True) async def _consume_claim_producer(self, claim_producer): count = 0 for op, doc in claim_producer: if op == 'delete': yield {'_index': self.index, '_op_type': 'delete', '_id': doc} else: yield extract_doc(doc, self.index) count += 1 if count % 100 == 0: self.logger.info("Indexing in progress, %d claims.", count) self.logger.info("Indexing done for %d claims.", count) async def claim_consumer(self, claim_producer): touched = set() async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer), raise_on_error=False): if not ok: self.logger.warning("indexing failed for an item: %s", item) else: item = item.popitem()[1] touched.add(item['_id']) await self.sync_client.indices.refresh(self.index) self.logger.info("Indexing done.") def update_filter_query(self, censor_type, blockdict, channels=False): blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()} if channels: update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}") else: update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}") key = 'channel_id' if channels else 'claim_id' update['script'] = { 
"source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_id=params[ctx._source.{key}]", "lang": "painless", "params": blockdict } return update async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels): if filtered_streams: await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4) await self.sync_client.indices.refresh(self.index) if filtered_channels: await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4) await self.sync_client.indices.refresh(self.index) await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4) await self.sync_client.indices.refresh(self.index) if blocked_streams: await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4) await self.sync_client.indices.refresh(self.index) if blocked_channels: await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4) await self.sync_client.indices.refresh(self.index) await self.sync_client.update_by_query( self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4) await self.sync_client.indices.refresh(self.index) self.clear_caches() def clear_caches(self): self.search_cache.clear() self.short_id_cache.clear() self.claim_cache.clear() self.resolution_cache.clear() async def session_query(self, query_name, kwargs): offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0 total_referenced = [] if query_name == 'resolve': total_referenced, response, censor = await self.resolve(*kwargs) else: cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache) if cache_item.result is not None: return cache_item.result async with cache_item.lock: if 
cache_item.result: return cache_item.result censor = Censor(Censor.SEARCH) if kwargs.get('no_totals'): response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) else: response, offset, total = await self.search(**kwargs) censor.apply(response) total_referenced.extend(response) if censor.censored: response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) total_referenced.extend(response) result = Outputs.to_base64( response, await self._get_referenced_rows(total_referenced), offset, total, censor ) cache_item.result = result return result return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor) async def resolve(self, *urls): censor = Censor(Censor.RESOLVE) results = [await self.resolve_url(url) for url in urls] # just heat the cache await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results)) results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)] censored = [ result if not isinstance(result, dict) or not censor.censor(result) else ResolveCensoredError(url, result['censoring_channel_id']) for url, result in zip(urls, results) ] return results, censored, censor def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]): cached = self.claim_cache.get(resolution) return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url)) async def get_many(self, *claim_ids): await self.populate_claim_cache(*claim_ids) return filter(None, map(self.claim_cache.get, claim_ids)) async def populate_claim_cache(self, *claim_ids): missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None] if missing: results = await self.search_client.mget( index=self.index, body={"ids": missing} ) for result in expand_result(filter(lambda doc: doc['found'], results["docs"])): self.claim_cache.set(result['claim_id'], result) 
async def full_id_from_short_id(self, name, short_id, channel_id=None): key = '#'.join((channel_id or '', name, short_id)) if key not in self.short_id_cache: query = {'name': name, 'claim_id': short_id} if channel_id: query['channel_id'] = channel_id query['order_by'] = ['^channel_join'] query['signature_valid'] = True else: query['order_by'] = '^creation_height' result, _, _ = await self.search(**query, limit=1) if len(result) == 1: result = result[0]['claim_id'] self.short_id_cache[key] = result return self.short_id_cache.get(key, None) async def search(self, **kwargs): if 'channel' in kwargs: kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel')) if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str): return [], 0, 0 try: return await self.search_ahead(**kwargs) except NotFoundError: return [], 0, 0 return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0) async def search_ahead(self, **kwargs): # 'limit_claims_per_channel' case. 
Fetch 1000 results, reorder, slice, inflate and return per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0 remove_duplicates = kwargs.pop('remove_duplicates', False) page_size = kwargs.pop('limit', 10) offset = kwargs.pop('offset', 0) kwargs['limit'] = 1000 cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache) if cache_item.result is not None: reordered_hits = cache_item.result else: async with cache_item.lock: if cache_item.result: reordered_hits = cache_item.result else: query = expand_query(**kwargs) search_hits = deque((await self.search_client.search( query, index=self.index, track_total_hits=False, _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height'] ))['hits']['hits']) if remove_duplicates: search_hits = self.__remove_duplicates(search_hits) if per_channel_per_page > 0: reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page) else: reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits] cache_item.result = reordered_hits result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)]))) return result, 0, len(reordered_hits) def __remove_duplicates(self, search_hits: deque) -> deque: known_ids = {} # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original dropped = set() for hit in search_hits: hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id'] if hit_id not in known_ids: known_ids[hit_id] = (hit_height, hit['_id']) else: previous_height, previous_id = known_ids[hit_id] if hit_height < previous_height: known_ids[hit_id] = (hit_height, hit['_id']) dropped.add(previous_id) else: dropped.add(hit['_id']) return deque(hit for hit in search_hits if hit['_id'] not in dropped) def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int): reordered_hits = [] 
channel_counters = Counter() next_page_hits_maybe_check_later = deque() while search_hits or next_page_hits_maybe_check_later: if reordered_hits and len(reordered_hits) % page_size == 0: channel_counters.clear() elif not reordered_hits: pass else: break # means last page was incomplete and we are left with bad replacements for _ in range(len(next_page_hits_maybe_check_later)): claim_id, channel_id = next_page_hits_maybe_check_later.popleft() if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page: reordered_hits.append((claim_id, channel_id)) channel_counters[channel_id] += 1 else: next_page_hits_maybe_check_later.append((claim_id, channel_id)) while search_hits: hit = search_hits.popleft() hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id'] if hit_channel_id is None or per_channel_per_page <= 0: reordered_hits.append((hit_id, hit_channel_id)) elif channel_counters[hit_channel_id] < per_channel_per_page: reordered_hits.append((hit_id, hit_channel_id)) channel_counters[hit_channel_id] += 1 if len(reordered_hits) % page_size == 0: break else: next_page_hits_maybe_check_later.append((hit_id, hit_channel_id)) return reordered_hits async def resolve_url(self, raw_url): if raw_url not in self.resolution_cache: self.resolution_cache[raw_url] = await self._resolve_url(raw_url) return self.resolution_cache[raw_url] async def _resolve_url(self, raw_url): try: url = URL.parse(raw_url) except ValueError as e: return e stream = LookupError(f'Could not find claim at "{raw_url}".') channel_id = await self.resolve_channel_id(url) if isinstance(channel_id, LookupError): return channel_id stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream if url.has_stream: return StreamResolution(stream) else: return ChannelResolution(channel_id) async def resolve_channel_id(self, url: URL): if not url.has_channel: return if url.channel.is_fullid: return url.channel.claim_id if url.channel.is_shortid: 
channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id) if not channel_id: return LookupError(f'Could not find channel in "{url}".') return channel_id query = url.channel.to_dict() if set(query) == {'name'}: query['is_controlling'] = True else: query['order_by'] = ['^creation_height'] matches, _, _ = await self.search(**query, limit=1) if matches: channel_id = matches[0]['claim_id'] else: return LookupError(f'Could not find channel in "{url}".') return channel_id async def resolve_stream(self, url: URL, channel_id: str = None): if not url.has_stream: return None if url.has_channel and channel_id is None: return None query = url.stream.to_dict() if url.stream.claim_id is not None: if url.stream.is_fullid: claim_id = url.stream.claim_id else: claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id) return claim_id if channel_id is not None: if set(query) == {'name'}: # temporarily emulate is_controlling for claims in channel query['order_by'] = ['effective_amount', '^height'] else: query['order_by'] = ['^channel_join'] query['channel_id'] = channel_id query['signature_valid'] = True elif set(query) == {'name'}: query['is_controlling'] = True matches, _, _ = await self.search(**query, limit=1) if matches: return matches[0]['claim_id'] async def _get_referenced_rows(self, txo_rows: List[dict]): txo_rows = [row for row in txo_rows if isinstance(row, dict)] referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows))) referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows))) referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows))) referenced_txos = [] if referenced_ids: referenced_txos.extend(await self.get_many(*referenced_ids)) referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos))) if referenced_ids: referenced_txos.extend(await self.get_many(*referenced_ids)) return referenced_txos