Code Example #1
File: search.py  Project: walbens59/lbry-sdk
class SearchIndex:
    def __init__(self, index_prefix: str, search_timeout=3.0):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
        self.search_cache = LRUCache(2 ** 17)
        self.resolution_cache = LRUCache(2 ** 17)

    async def start(self):
        if self.sync_client:
            return
        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)
        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        return res.get('acknowledged', False)

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.search_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def session_query(self, query_name, kwargs):
        offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                if cache_item.result:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                if kwargs.get('no_totals'):
                    response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                else:
                    response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        # just warm the cache
        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_hash'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
        cached = self.claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        key = (channel_id or '') + name + short_id
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) == 1:
                result = result[0]['claim_id']
                self.short_id_cache[key] = result
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        if 'channel' in kwargs:
            kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
            if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
                return [], 0, 0
        try:
            result = (await self.search_client.search(
                expand_query(**kwargs), index=self.index, track_total_hits=False if kwargs.get('no_totals') else 10_000
            ))['hits']
        except NotFoundError:
            return [], 0, 0
        return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def resolve_url(self, raw_url):
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        stream = LookupError(f'Could not find claim at "{raw_url}".')

        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: URL):
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: URL, channel_id: str = None):
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(map(parse_claim_id, filter(None, (row['censoring_channel_hash'] for row in txo_rows))))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
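
A minimal usage sketch for this first variant. Assumptions: an Elasticsearch node is reachable with default settings, the import path and the query parameters below are placeholders, and the URLs/field names are illustrative only.

import asyncio

# Hypothetical import path; adjust to wherever this SearchIndex class actually lives.
from lbry.wallet.server.db.elasticsearch.search import SearchIndex


async def main():
    index = SearchIndex(index_prefix='lbry_', search_timeout=3.0)
    await index.start()  # waits for cluster health, creates '<prefix>claims' if missing
    try:
        # 'search' queries take a kwargs dict and go through the result cache;
        # 'resolve' takes an iterable of raw URLs. Both return base64-encoded Outputs.
        search_result = await index.session_query('search', {'name': 'example', 'limit': 5})
        resolve_result = await index.session_query('resolve', ['lbry://@example-channel'])
        print(len(search_result), len(resolve_result))
    finally:
        await index.stop()  # closes both AsyncElasticsearch clients


if __name__ == '__main__':
    asyncio.run(main())
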
Code Example #2
File: search.py  Project: nishp77/lbry-sdk
class SearchIndex:
    VERSION = 1

    def __init__(self,
                 index_prefix: str,
                 search_timeout=3.0,
                 elastic_host='localhost',
                 elastic_port=9200):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2**15)
        self.search_cache = LRUCache(2**17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        await self.sync_client.indices.put_template(self.index,
                                                    body={
                                                        'version':
                                                        version,
                                                        'index_patterns':
                                                        ['ignored']
                                                    },
                                                    ignore=400)

    async def start(self) -> bool:
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts,
                                                timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning(
                    "Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        res = await self.sync_client.indices.create(self.index,
                                                    INDEX_DEFAULT_SETTINGS,
                                                    ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error(
                "es search index has an incompatible version: %s vs %s",
                index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        await self.sync_client.indices.refresh(self.index)
        return acked

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(
            asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index,
                                               ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        async for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield {
                    'doc': {
                        key: value
                        for key, value in doc.items() if key in ALL_FIELDS
                    },
                    '_id': doc['claim_id'],
                    '_index': self.index,
                    '_op_type': 'update',
                    'doc_as_upsert': True
                }
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        if count:
            self.logger.info("Indexing done for %d claims.", count)
        else:
            self.logger.debug("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(
                self.sync_client,
                self._consume_claim_producer(claim_producer),
                raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.debug("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {
            blocked.hex(): blocker.hex()
            for blocked, blocker in blockdict.items()
        }
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()),
                                  censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()),
                                  censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; "
            f"ctx._source.censoring_channel_id=params[ctx._source.{key}];",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def update_trending_score(self, params):
        update_trending_score_script = """
        double softenLBC(double lbc) { return (Math.pow(lbc, 1.0 / 3.0)); }

        double logsumexp(double x, double y)
        {
            double top;
            if(x > y)
                top = x;
            else
                top = y;
            double result = top + Math.log(Math.exp(x-top) + Math.exp(y-top));
            return(result);
        }

        double logdiffexp(double big, double small)
        {
            return big + Math.log(1.0 - Math.exp(small - big));
        }

        double squash(double x)
        {
            if(x < 0.0)
                return -Math.log(1.0 - x);
            else
                return Math.log(x + 1.0);
        }

        double unsquash(double x)
        {
            if(x < 0.0)
                return 1.0 - Math.exp(-x);
            else
                return Math.exp(x) - 1.0;
        }

        double log_to_squash(double x)
        {
            return logsumexp(x, 0.0);
        }

        double squash_to_log(double x)
        {
            //assert x > 0.0;
            return logdiffexp(x, 0.0);
        }

        double squashed_add(double x, double y)
        {
            // squash(unsquash(x) + unsquash(y)) but avoiding overflow.
            // Cases where the signs are the same
            if (x < 0.0 && y < 0.0)
                return -logsumexp(-x, logdiffexp(-y, 0.0));
            if (x >= 0.0 && y >= 0.0)
                return logsumexp(x, logdiffexp(y, 0.0));
            // Where the signs differ
            if (x >= 0.0 && y < 0.0)
                if (Math.abs(x) >= Math.abs(y))
                    return logsumexp(0.0, logdiffexp(x, -y));
                else
                    return -logsumexp(0.0, logdiffexp(-y, x));
            if (x < 0.0 && y >= 0.0)
            {
                // Addition is commutative, hooray for new math
                return squashed_add(y, x);
            }
            return 0.0;
        }

        double squashed_multiply(double x, double y)
        {
            // squash(unsquash(x)*unsquash(y)) but avoiding overflow.
            int sign;
            if(x*y >= 0.0)
                sign = 1;
            else
                sign = -1;
            return sign*logsumexp(squash_to_log(Math.abs(x))
                            + squash_to_log(Math.abs(y)), 0.0);
        }

        // Squashed inflated units
        double inflateUnits(int height) {
            double timescale = 576.0; // Half life of 400 = e-folding time of a day
                                      // by coincidence, so may as well go with it
            return log_to_squash(height / timescale);
        }

        double spikePower(double newAmount) {
            if (newAmount < 50.0) {
                return(0.5);
            } else if (newAmount < 85.0) {
                return(newAmount / 100.0);
            } else {
                return(0.85);
            }
        }

        double spikeMass(double oldAmount, double newAmount) {
            double softenedChange = softenLBC(Math.abs(newAmount - oldAmount));
            double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount));
            double power = spikePower(newAmount);
            if (oldAmount > newAmount) {
                return(-1.0 * Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power));
            } else {
                return(Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power));
            }
        }
        for (i in params.src.changes) {
            double units = inflateUnits(i.height);
            if (ctx._source.trending_score == null) {
                ctx._source.trending_score = 0.0;
            }
            double bigSpike = squashed_multiply(units, squash(spikeMass(i.prev_amount, i.new_amount)));
            ctx._source.trending_score = squashed_add(ctx._source.trending_score, bigSpike);
        }
        """
        start = time.perf_counter()

        def producer():
            for claim_id, claim_updates in params.items():
                yield {
                    '_id': claim_id,
                    '_index': self.index,
                    '_op_type': 'update',
                    'script': {
                        'lang': 'painless',
                        'source': update_trending_score_script,
                        'params': {
                            'src': {
                                'changes': [{
                                    'height': p.height,
                                    'prev_amount': p.prev_amount / 1E8,
                                    'new_amount': p.new_amount / 1E8,
                                } for p in claim_updates]
                            }
                        }
                    },
                }

        if not params:
            return
        async for ok, item in async_streaming_bulk(self.sync_client,
                                                   producer(),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("updating trending failed for an item: %s",
                                    item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("updated trending scores in %ims",
                         int((time.perf_counter() - start) * 1000))

    async def apply_filters(self, blocked_streams, blocked_channels,
                            filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH, filtered_streams),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH,
                                              filtered_channels),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH, filtered_channels,
                                              True),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE, blocked_streams),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE,
                                              blocked_channels),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE, blocked_channels,
                                              True),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.clear_caches()

    def clear_caches(self):
        self.search_cache.clear()
        self.claim_cache.clear()

    async def cached_search(self, kwargs):
        total_referenced = []
        cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
        if cache_item.result is not None:
            return cache_item.result
        async with cache_item.lock:
            if cache_item.result:
                return cache_item.result
            censor = Censor(Censor.SEARCH)
            if kwargs.get('no_totals'):
                response, offset, total = await self.search(
                    **kwargs, censor_type=Censor.NOT_CENSORED)
            else:
                response, offset, total = await self.search(**kwargs)
            censor.apply(response)
            total_referenced.extend(response)

            if censor.censored:
                response, _, _ = await self.search(
                    **kwargs, censor_type=Censor.NOT_CENSORED)
                total_referenced.extend(response)
            response = [
                ResolveResult(name=r['claim_name'],
                              normalized_name=r['normalized_name'],
                              claim_hash=r['claim_hash'],
                              tx_num=r['tx_num'],
                              position=r['tx_nout'],
                              tx_hash=r['tx_hash'],
                              height=r['height'],
                              amount=r['amount'],
                              short_url=r['short_url'],
                              is_controlling=r['is_controlling'],
                              canonical_url=r['canonical_url'],
                              creation_height=r['creation_height'],
                              activation_height=r['activation_height'],
                              expiration_height=r['expiration_height'],
                              effective_amount=r['effective_amount'],
                              support_amount=r['support_amount'],
                              last_takeover_height=r['last_take_over_height'],
                              claims_in_channel=r['claims_in_channel'],
                              channel_hash=r['channel_hash'],
                              reposted_claim_hash=r['reposted_claim_hash'],
                              reposted=r['reposted'],
                              signature_valid=r['signature_valid'])
                for r in response
            ]
            extra = [
                ResolveResult(name=r['claim_name'],
                              normalized_name=r['normalized_name'],
                              claim_hash=r['claim_hash'],
                              tx_num=r['tx_num'],
                              position=r['tx_nout'],
                              tx_hash=r['tx_hash'],
                              height=r['height'],
                              amount=r['amount'],
                              short_url=r['short_url'],
                              is_controlling=r['is_controlling'],
                              canonical_url=r['canonical_url'],
                              creation_height=r['creation_height'],
                              activation_height=r['activation_height'],
                              expiration_height=r['expiration_height'],
                              effective_amount=r['effective_amount'],
                              support_amount=r['support_amount'],
                              last_takeover_height=r['last_take_over_height'],
                              claims_in_channel=r['claims_in_channel'],
                              channel_hash=r['channel_hash'],
                              reposted_claim_hash=r['reposted_claim_hash'],
                              reposted=r['reposted'],
                              signature_valid=r['signature_valid'])
                for r in await self._get_referenced_rows(total_referenced)
            ]
            result = Outputs.to_base64(response, extra, offset, total, censor)
            cache_item.result = result
            return result

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [
            claim_id for claim_id in claim_ids
            if self.claim_cache.get(claim_id) is None
        ]
        if missing:
            results = await self.search_client.mget(index=self.index,
                                                    body={"ids": missing})
            for result in expand_result(
                    filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def search(self, **kwargs):
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0
        # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def search_ahead(self, **kwargs):
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(
            f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                if cache_item.result:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque(
                        (await
                         self.search_client.search(query,
                                                   index=self.index,
                                                   track_total_hits=False,
                                                   _source_includes=[
                                                       '_id', 'channel_id',
                                                       'reposted_claim_id',
                                                       'creation_height'
                                                   ]))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(
                            search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'],
                                           hit['_source']['channel_id'])
                                          for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(
            *(claim_id
              for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        known_ids = {
        }  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit[
                '_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: list, page_size: int,
                       per_channel_per_page: int):
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft(
                )
                if per_channel_per_page > 0 and channel_counters[
                        channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append(
                        (claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source'][
                    'channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append(
                        (hit_id, hit_channel_id))
        return reordered_hits

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(
            filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(
            filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(
            filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(
                filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
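
The update_trending_score coroutine in this variant reads only three attributes from each update object: height, prev_amount and new_amount (amounts in dewies, divided by 1E8 before being handed to the Painless script). Below is a sketch of the expected params shape; the claim id and amounts are made up, and TrendingChange is a stand-in for whatever type the real caller passes.

from collections import namedtuple

# Stand-in for the caller's update objects; only these three fields are read
# by producer() inside update_trending_score().
TrendingChange = namedtuple('TrendingChange', 'height prev_amount new_amount')


async def push_trending(index):
    params = {
        # claim_id (40 hex chars) -> list of changes, amounts in dewies (LBC * 1e8)
        'deadbeef' * 5: [
            TrendingChange(height=1_000_000, prev_amount=0, new_amount=5 * 10**8),
            TrendingChange(height=1_000_010, prev_amount=5 * 10**8, new_amount=20 * 10**8),
        ],
    }
    await index.update_trending_score(params)
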
Code Example #3
File: search.py  Project: belikor/lbry-sdk
class SearchIndex:
    VERSION = 1

    def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)
        self.search_cache = LRUCache(2 ** 17)
        self.resolution_cache = LRUCache(2 ** 17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        await self.sync_client.indices.put_template(
            self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400
        )

    async def start(self) -> bool:
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        return acked

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_id=params[ctx._source.{key}]",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.clear_caches()

    def clear_caches(self):
        self.search_cache.clear()
        self.short_id_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def session_query(self, query_name, kwargs):
        offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                if cache_item.result:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                if kwargs.get('no_totals'):
                    response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                else:
                    response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        # just warm the cache
        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_id'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
        cached = self.claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        key = '#'.join((channel_id or '', name, short_id))
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) == 1:
                result = result[0]['claim_id']
                self.short_id_cache[key] = result
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        if 'channel' in kwargs:
            kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
            if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
                return [], 0, 0
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0
        # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def search_ahead(self, **kwargs):
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                if cache_item.result:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque((await self.search_client.search(
                        query, index=self.index, track_total_hits=False,
                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                    ))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
        return reordered_hits

    async def resolve_url(self, raw_url):
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        stream = LookupError(f'Could not find claim at "{raw_url}".')

        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: URL):
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: URL, channel_id: str = None):
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
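
For reference, the reordering behind search_ahead/__search_ahead caps how many hits a single channel may contribute to each page of results, deferring the overflow to later pages. Below is a simplified standalone sketch of that idea on toy (claim_id, channel_id) pairs; it is not the class method itself and omits the "incomplete last page" early exit.

from collections import Counter, deque


def cap_per_channel(hits, page_size, per_channel_per_page):
    """Reorder (claim_id, channel_id) pairs so each channel appears at most
    per_channel_per_page times per page of page_size; overflow is pushed back."""
    hits = deque(hits)
    deferred = deque()
    reordered = []
    while hits or deferred:
        page_counts = Counter()
        page = []
        # Give previously deferred hits first shot at the new page (fixed count,
        # so items re-deferred here are not revisited within the same page).
        for _ in range(len(deferred)):
            if len(page) >= page_size:
                break
            claim_id, channel_id = deferred.popleft()
            if channel_id is None or page_counts[channel_id] < per_channel_per_page:
                page.append((claim_id, channel_id))
                page_counts[channel_id] += 1
            else:
                deferred.append((claim_id, channel_id))
        # Then fill the rest of the page from fresh hits.
        while hits and len(page) < page_size:
            claim_id, channel_id = hits.popleft()
            if channel_id is None or page_counts[channel_id] < per_channel_per_page:
                page.append((claim_id, channel_id))
                page_counts[channel_id] += 1
            else:
                deferred.append((claim_id, channel_id))
        if not page:
            # Everything left exceeds the cap for any page we could build; emit as-is.
            reordered.extend(deferred)
            break
        reordered.extend(page)
    return reordered


hits = [('c1', 'chanA'), ('c2', 'chanA'), ('c3', 'chanA'), ('c4', 'chanB'), ('c5', None)]
print(cap_per_channel(hits, page_size=3, per_channel_per_page=1))
# [('c1', 'chanA'), ('c4', 'chanB'), ('c5', None), ('c2', 'chanA'), ('c3', 'chanA')]
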