Example #1
 def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200):
     self.search_timeout = search_timeout
     self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
     self.search_client: Optional[AsyncElasticsearch] = None
     self.sync_client: Optional[AsyncElasticsearch] = None
     self.index = index_prefix + 'claims'
     self.logger = class_logger(__name__, self.__class__.__name__)
     self.claim_cache = LRUCache(2 ** 15)
     self.search_cache = LRUCache(2 ** 17)
     self._elastic_host = elastic_host
     self._elastic_port = elastic_port
Example #2
 def __init__(self, height: int, tip: bytes, throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10,
              allow_localhost: bool = False, allow_lan: bool = False):
     super().__init__()
     self.transport: Optional[asyncio.transports.DatagramTransport] = None
     self._height = height
     self._tip = tip
     self._flags = 0
     self._cached_response = None
     self.update_cached_response()
     self._throttle = LRUCache(throttle_cache_size)
     self._should_log = LRUCache(throttle_cache_size)
     self._min_delay = 1 / throttle_reqs_per_sec
     self._allow_localhost = allow_localhost
     self._allow_lan = allow_lan
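
The throttle fields above pair an LRUCache keyed by remote address with a minimum delay derived from throttle_reqs_per_sec. The sketch below is a hypothetical illustration of how such a per-sender rate check could be applied; check_throttle and its parameters are made-up names, not part of the class shown.

import time
from typing import Tuple

def check_throttle(throttle, min_delay: float, address: Tuple[str, int]) -> bool:
    # Hypothetical helper: allow a datagram from `address` only if enough time
    # has passed since the last one handled from that address.
    now = time.perf_counter()
    last = throttle.get(address)   # LRUCache.get returns None for unseen or evicted keys
    if last is not None and now - last < min_delay:
        return False               # too soon; the caller can drop the request
    throttle[address] = now        # LRUCache supports item assignment
    return True
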
Example #3
    def __init__(self, env):
        self.logger = util.class_logger(__name__, self.__class__.__name__)
        self.env = env
        self.coin = env.coin
        self.executor = None

        self.logger.info(f'switching current directory to {env.db_dir}')

        self.db_class = db_class(env.db_dir, self.env.db_engine)
        self.history = History()
        self.utxo_db = None
        self.tx_counts = None
        self.headers = None
        self.last_flush = time.time()

        self.logger.info(f'using {self.env.db_engine} for DB backend')

        # Header merkle cache
        self.merkle = Merkle()
        self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)

        self.headers_db = None
        self.tx_db = None

        self._tx_and_merkle_cache = LRUCache(2**17, metric_name='tx_and_merkle', namespace="wallet_server")
        self.total_transactions = None
Example #4
 def __init__(self, index_prefix: str, search_timeout=3.0):
     self.search_timeout = search_timeout
     self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
     self.search_client: Optional[AsyncElasticsearch] = None
     self.sync_client: Optional[AsyncElasticsearch] = None
     self.index = index_prefix + 'claims'
     self.logger = class_logger(__name__, self.__class__.__name__)
     self.claim_cache = LRUCache(2 ** 15)
     self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
     self.search_cache = LRUCache(2 ** 17)
     self.resolution_cache = LRUCache(2 ** 17)
Example #5
 def __init__(self,
              coin,
              url,
              max_workqueue=10,
              init_retry=0.25,
              max_retry=4.0):
     self.coin = coin
     self.logger = class_logger(__name__, self.__class__.__name__)
     self.set_url(url)
     # Limit concurrent RPC calls to this number.
     # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
     self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
     self.init_retry = init_retry
     self.max_retry = max_retry
     self._height = None
     self.available_rpcs = {}
     self.connector = aiohttp.TCPConnector()
     self._block_hash_cache = LRUCache(100000)
     self._block_cache = LRUCache(2**16,
                                  metric_name='block',
                                  namespace=NAMESPACE)
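
The comment above workqueue_semaphore explains its purpose: capping concurrent RPC calls so bitcoind's HTTP work queue is not overrun. A minimal sketch of that guarding pattern using only standard asyncio follows; the function name and the send_request callable are illustrative, not part of the daemon class shown.

import asyncio

async def limited_rpc(semaphore: asyncio.Semaphore, send_request):
    # At most `max_workqueue` of these blocks run at once; additional callers
    # wait here instead of piling extra requests onto the daemon.
    async with semaphore:
        return await send_request()
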
Example #6
 def __init__(self, loop: asyncio.AbstractEventLoop):
     self._loop = loop
     self._rpc_failures: typing.Dict[
         typing.Tuple[str, int], typing.Tuple[typing.Optional[float], typing.Optional[float]]
     ] = LRUCache(CACHE_SIZE)
     self._last_replied: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE)
     self._last_sent: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE)
     self._last_requested: typing.Dict[typing.Tuple[str, int], float] = LRUCache(CACHE_SIZE)
     self._node_id_mapping: typing.Dict[typing.Tuple[str, int], bytes] = LRUCache(CACHE_SIZE)
     self._node_id_reverse_mapping: typing.Dict[bytes, typing.Tuple[str, int]] = LRUCache(CACHE_SIZE)
     self._node_tokens: typing.Dict[bytes, typing.Tuple[float, bytes]] = LRUCache(CACHE_SIZE)
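
Example #6 annotates its LRUCache instances as typing.Dict, reflecting the dict-like surface the examples above rely on. A minimal usage sketch under that assumption follows; the import path is a guess, so point it at wherever the project actually defines LRUCache.

from lbry.utils import LRUCache  # assumed import path

cache = LRUCache(4)           # bounded to four entries, like LRUCache(2 ** 15) above
cache['a'] = 1                # item assignment, as used for resolution_cache
cache.set('b', 2)             # explicit set, as used for claim_cache
assert cache.get('a') == 1    # .get returns None for missing or evicted keys
assert 'c' not in cache       # membership test, as used for short_id_cache
cache.clear()                 # drop everything, as in clear_caches()
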
Example #7
class SearchIndex:
    VERSION = 1

    def __init__(self,
                 index_prefix: str,
                 search_timeout=3.0,
                 elastic_host='localhost',
                 elastic_port=9200):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2**15)
        self.search_cache = LRUCache(2**17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        await self.sync_client.indices.put_template(self.index,
                                                    body={
                                                        'version':
                                                        version,
                                                        'index_patterns':
                                                        ['ignored']
                                                    },
                                                    ignore=400)

    async def start(self) -> bool:
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts,
                                                timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning(
                    "Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        res = await self.sync_client.indices.create(self.index,
                                                    INDEX_DEFAULT_SETTINGS,
                                                    ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error(
                "es search index has an incompatible version: %s vs %s",
                index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        await self.sync_client.indices.refresh(self.index)
        return acked

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(
            asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index,
                                               ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        async for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield {
                    'doc': {
                        key: value
                        for key, value in doc.items() if key in ALL_FIELDS
                    },
                    '_id': doc['claim_id'],
                    '_index': self.index,
                    '_op_type': 'update',
                    'doc_as_upsert': True
                }
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        if count:
            self.logger.info("Indexing done for %d claims.", count)
        else:
            self.logger.debug("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(
                self.sync_client,
                self._consume_claim_producer(claim_producer),
                raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.debug("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {
            blocked.hex(): blocker.hex()
            for blocked, blocker in blockdict.items()
        }
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()),
                                  censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()),
                                  censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; "
            f"ctx._source.censoring_channel_id=params[ctx._source.{key}];",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def update_trending_score(self, params):
        update_trending_score_script = """
        double softenLBC(double lbc) { return (Math.pow(lbc, 1.0 / 3.0)); }

        double logsumexp(double x, double y)
        {
            double top;
            if(x > y)
                top = x;
            else
                top = y;
            double result = top + Math.log(Math.exp(x-top) + Math.exp(y-top));
            return(result);
        }

        double logdiffexp(double big, double small)
        {
            return big + Math.log(1.0 - Math.exp(small - big));
        }

        double squash(double x)
        {
            if(x < 0.0)
                return -Math.log(1.0 - x);
            else
                return Math.log(x + 1.0);
        }

        double unsquash(double x)
        {
            if(x < 0.0)
                return 1.0 - Math.exp(-x);
            else
                return Math.exp(x) - 1.0;
        }

        double log_to_squash(double x)
        {
            return logsumexp(x, 0.0);
        }

        double squash_to_log(double x)
        {
            //assert x > 0.0;
            return logdiffexp(x, 0.0);
        }

        double squashed_add(double x, double y)
        {
            // squash(unsquash(x) + unsquash(y)) but avoiding overflow.
            // Cases where the signs are the same
            if (x < 0.0 && y < 0.0)
                return -logsumexp(-x, logdiffexp(-y, 0.0));
            if (x >= 0.0 && y >= 0.0)
                return logsumexp(x, logdiffexp(y, 0.0));
            // Where the signs differ
            if (x >= 0.0 && y < 0.0)
                if (Math.abs(x) >= Math.abs(y))
                    return logsumexp(0.0, logdiffexp(x, -y));
                else
                    return -logsumexp(0.0, logdiffexp(-y, x));
            if (x < 0.0 && y >= 0.0)
            {
                // Addition is commutative, hooray for new math
                return squashed_add(y, x);
            }
            return 0.0;
        }

        double squashed_multiply(double x, double y)
        {
            // squash(unsquash(x)*unsquash(y)) but avoiding overflow.
            int sign;
            if(x*y >= 0.0)
                sign = 1;
            else
                sign = -1;
            return sign*logsumexp(squash_to_log(Math.abs(x))
                            + squash_to_log(Math.abs(y)), 0.0);
        }

        // Squashed inflated units
        double inflateUnits(int height) {
            double timescale = 576.0; // Half life of 400 = e-folding time of a day
                                      // by coincidence, so may as well go with it
            return log_to_squash(height / timescale);
        }

        double spikePower(double newAmount) {
            if (newAmount < 50.0) {
                return(0.5);
            } else if (newAmount < 85.0) {
                return(newAmount / 100.0);
            } else {
                return(0.85);
            }
        }

        double spikeMass(double oldAmount, double newAmount) {
            double softenedChange = softenLBC(Math.abs(newAmount - oldAmount));
            double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount));
            double power = spikePower(newAmount);
            if (oldAmount > newAmount) {
                return -1.0 * Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power);
            } else {
                return Math.pow(changeInSoftened, power) * Math.pow(softenedChange, 1.0 - power);
            }
        }
        for (i in params.src.changes) {
            double units = inflateUnits(i.height);
            if (ctx._source.trending_score == null) {
                ctx._source.trending_score = 0.0;
            }
            double bigSpike = squashed_multiply(units, squash(spikeMass(i.prev_amount, i.new_amount)));
            ctx._source.trending_score = squashed_add(ctx._source.trending_score, bigSpike);
        }
        """
        start = time.perf_counter()

        def producer():
            for claim_id, claim_updates in params.items():
                yield {
                    '_id': claim_id,
                    '_index': self.index,
                    '_op_type': 'update',
                    'script': {
                        'lang': 'painless',
                        'source': update_trending_score_script,
                        'params': {
                            'src': {
                                'changes': [{
                                    'height': p.height,
                                    'prev_amount': p.prev_amount / 1E8,
                                    'new_amount': p.new_amount / 1E8,
                                } for p in claim_updates]
                            }
                        }
                    },
                }

        if not params:
            return
        async for ok, item in async_streaming_bulk(self.sync_client,
                                                   producer(),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("updating trending failed for an item: %s",
                                    item)
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("updated trending scores in %ims",
                         int((time.perf_counter() - start) * 1000))

    async def apply_filters(self, blocked_streams, blocked_channels,
                            filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH, filtered_streams),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH,
                                              filtered_channels),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.SEARCH, filtered_channels,
                                              True),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE, blocked_streams),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE,
                                              blocked_channels),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index,
                body=self.update_filter_query(Censor.RESOLVE, blocked_channels,
                                              True),
                slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.clear_caches()

    def clear_caches(self):
        self.search_cache.clear()
        self.claim_cache.clear()

    async def cached_search(self, kwargs):
        total_referenced = []
        cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
        if cache_item.result is not None:
            return cache_item.result
        async with cache_item.lock:
            if cache_item.result:
                return cache_item.result
            censor = Censor(Censor.SEARCH)
            if kwargs.get('no_totals'):
                response, offset, total = await self.search(
                    **kwargs, censor_type=Censor.NOT_CENSORED)
            else:
                response, offset, total = await self.search(**kwargs)
            censor.apply(response)
            total_referenced.extend(response)

            if censor.censored:
                response, _, _ = await self.search(
                    **kwargs, censor_type=Censor.NOT_CENSORED)
                total_referenced.extend(response)
            response = [
                ResolveResult(name=r['claim_name'],
                              normalized_name=r['normalized_name'],
                              claim_hash=r['claim_hash'],
                              tx_num=r['tx_num'],
                              position=r['tx_nout'],
                              tx_hash=r['tx_hash'],
                              height=r['height'],
                              amount=r['amount'],
                              short_url=r['short_url'],
                              is_controlling=r['is_controlling'],
                              canonical_url=r['canonical_url'],
                              creation_height=r['creation_height'],
                              activation_height=r['activation_height'],
                              expiration_height=r['expiration_height'],
                              effective_amount=r['effective_amount'],
                              support_amount=r['support_amount'],
                              last_takeover_height=r['last_take_over_height'],
                              claims_in_channel=r['claims_in_channel'],
                              channel_hash=r['channel_hash'],
                              reposted_claim_hash=r['reposted_claim_hash'],
                              reposted=r['reposted'],
                              signature_valid=r['signature_valid'])
                for r in response
            ]
            extra = [
                ResolveResult(name=r['claim_name'],
                              normalized_name=r['normalized_name'],
                              claim_hash=r['claim_hash'],
                              tx_num=r['tx_num'],
                              position=r['tx_nout'],
                              tx_hash=r['tx_hash'],
                              height=r['height'],
                              amount=r['amount'],
                              short_url=r['short_url'],
                              is_controlling=r['is_controlling'],
                              canonical_url=r['canonical_url'],
                              creation_height=r['creation_height'],
                              activation_height=r['activation_height'],
                              expiration_height=r['expiration_height'],
                              effective_amount=r['effective_amount'],
                              support_amount=r['support_amount'],
                              last_takeover_height=r['last_take_over_height'],
                              claims_in_channel=r['claims_in_channel'],
                              channel_hash=r['channel_hash'],
                              reposted_claim_hash=r['reposted_claim_hash'],
                              reposted=r['reposted'],
                              signature_valid=r['signature_valid'])
                for r in await self._get_referenced_rows(total_referenced)
            ]
            result = Outputs.to_base64(response, extra, offset, total, censor)
            cache_item.result = result
            return result

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [
            claim_id for claim_id in claim_ids
            if self.claim_cache.get(claim_id) is None
        ]
        if missing:
            results = await self.search_client.mget(index=self.index,
                                                    body={"ids": missing})
            for result in expand_result(
                    filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def search(self, **kwargs):
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0
        # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def search_ahead(self, **kwargs):
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(
            f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                if cache_item.result:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque(
                        (await
                         self.search_client.search(query,
                                                   index=self.index,
                                                   track_total_hits=False,
                                                   _source_includes=[
                                                       '_id', 'channel_id',
                                                       'reposted_claim_id',
                                                       'creation_height'
                                                   ]))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(
                            search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'],
                                           hit['_source']['channel_id'])
                                          for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(
            *(claim_id
              for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit[
                '_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: list, page_size: int,
                       per_channel_per_page: int):
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
                if per_channel_per_page > 0 and channel_counters[
                        channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append(
                        (claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source'][
                    'channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append(
                        (hit_id, hit_channel_id))
        return reordered_hits

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(
            filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(
            filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(
            filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(
                filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
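
A hypothetical driver for the SearchIndex in Example #7, showing the lifecycle its methods imply: start() waits for Elasticsearch and creates or validates the index, cached_search() returns an encoded Outputs payload, and stop() returns a future that closes both clients. The index prefix and query below are made up.

import asyncio

async def main():
    index = SearchIndex('claims_', search_timeout=3.0, elastic_host='localhost', elastic_port=9200)
    await index.start()     # blocks until the cluster is reachable and the index version matches
    try:
        payload = await index.cached_search({'name': 'example', 'limit': 10})
        print('encoded result length:', len(payload))
    finally:
        await index.stop()  # awaits the gather() future over client.close()

asyncio.run(main())
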
Example #8
class SearchIndex:
    def __init__(self, index_prefix: str, search_timeout=3.0):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)  # never invalidated, since short ids are forever
        self.search_cache = LRUCache(2 ** 17)
        self.resolution_cache = LRUCache(2 ** 17)

    async def start(self):
        if self.sync_client:
            return
        self.sync_client = AsyncElasticsearch(timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)
        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        return res.get('acknowledged', False)

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_hash=params[ctx._source.{key}]",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.search_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def session_query(self, query_name, kwargs):
        offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                if cache_item.result:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                if kwargs.get('no_totals'):
                    response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                else:
                    response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        # just heat the cache
        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_hash'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
        cached = self.claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        key = (channel_id or '') + name + short_id
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) == 1:
                result = result[0]['claim_id']
                self.short_id_cache[key] = result
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        if 'channel' in kwargs:
            kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
            if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
                return [], 0, 0
        try:
            result = (await self.search_client.search(
                expand_query(**kwargs), index=self.index, track_total_hits=False if kwargs.get('no_totals') else 10_000
            ))['hits']
        except NotFoundError:
            return [], 0, 0
        return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def resolve_url(self, raw_url):
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        stream = LookupError(f'Could not find claim at "{raw_url}".')

        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: URL):
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: URL, channel_id: str = None):
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(map(parse_claim_id, filter(None, (row['censoring_channel_hash'] for row in txo_rows))))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
Example #9
    def __init__(self, loop: asyncio.BaseEventLoop, blob_dir: str, storage: 'SQLiteStorage', config: 'Config',
                 node_data_store: typing.Optional['DictDataStore'] = None):
        """
        This class stores blobs on the hard disk

        blob_dir - directory where blobs are stored
        storage - SQLiteStorage object
        """
        self.loop = loop
        self.blob_dir = blob_dir
        self.storage = storage
        self._node_data_store = node_data_store
        self.completed_blob_hashes: typing.Set[str] = set() if not self._node_data_store\
            else self._node_data_store.completed_blobs
        self.blobs: typing.Dict[str, AbstractBlob] = {}
        self.config = config
        self.decrypted_blob_lru_cache = None if not self.config.blob_lru_cache_size else LRUCache(
            self.config.blob_lru_cache_size)
        self.connection_manager = ConnectionManager(loop)
Example #10
class SearchIndex:
    VERSION = 1

    def __init__(self, index_prefix: str, search_timeout=3.0, elastic_host='localhost', elastic_port=9200):
        self.search_timeout = search_timeout
        self.sync_timeout = 600  # won't hit that 99% of the time, but can hit on a fresh import
        self.search_client: Optional[AsyncElasticsearch] = None
        self.sync_client: Optional[AsyncElasticsearch] = None
        self.index = index_prefix + 'claims'
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.claim_cache = LRUCache(2 ** 15)
        self.short_id_cache = LRUCache(2 ** 17)
        self.search_cache = LRUCache(2 ** 17)
        self.resolution_cache = LRUCache(2 ** 17)
        self._elastic_host = elastic_host
        self._elastic_port = elastic_port

    async def get_index_version(self) -> int:
        try:
            template = await self.sync_client.indices.get_template(self.index)
            return template[self.index]['version']
        except NotFoundError:
            return 0

    async def set_index_version(self, version):
        await self.sync_client.indices.put_template(
            self.index, body={'version': version, 'index_patterns': ['ignored']}, ignore=400
        )

    async def start(self) -> bool:
        if self.sync_client:
            return False
        hosts = [{'host': self._elastic_host, 'port': self._elastic_port}]
        self.sync_client = AsyncElasticsearch(hosts, timeout=self.sync_timeout)
        self.search_client = AsyncElasticsearch(hosts, timeout=self.search_timeout)
        while True:
            try:
                await self.sync_client.cluster.health(wait_for_status='yellow')
                break
            except ConnectionError:
                self.logger.warning("Failed to connect to Elasticsearch. Waiting for it!")
                await asyncio.sleep(1)

        res = await self.sync_client.indices.create(self.index, INDEX_DEFAULT_SETTINGS, ignore=400)
        acked = res.get('acknowledged', False)
        if acked:
            await self.set_index_version(self.VERSION)
            return acked
        index_version = await self.get_index_version()
        if index_version != self.VERSION:
            self.logger.error("es search index has an incompatible version: %s vs %s", index_version, self.VERSION)
            raise IndexVersionMismatch(index_version, self.VERSION)
        return acked

    def stop(self):
        clients = [self.sync_client, self.search_client]
        self.sync_client, self.search_client = None, None
        return asyncio.ensure_future(asyncio.gather(*(client.close() for client in clients)))

    def delete_index(self):
        return self.sync_client.indices.delete(self.index, ignore_unavailable=True)

    async def _consume_claim_producer(self, claim_producer):
        count = 0
        for op, doc in claim_producer:
            if op == 'delete':
                yield {'_index': self.index, '_op_type': 'delete', '_id': doc}
            else:
                yield extract_doc(doc, self.index)
            count += 1
            if count % 100 == 0:
                self.logger.info("Indexing in progress, %d claims.", count)
        self.logger.info("Indexing done for %d claims.", count)

    async def claim_consumer(self, claim_producer):
        touched = set()
        async for ok, item in async_streaming_bulk(self.sync_client, self._consume_claim_producer(claim_producer),
                                                   raise_on_error=False):
            if not ok:
                self.logger.warning("indexing failed for an item: %s", item)
            else:
                item = item.popitem()[1]
                touched.add(item['_id'])
        await self.sync_client.indices.refresh(self.index)
        self.logger.info("Indexing done.")

    def update_filter_query(self, censor_type, blockdict, channels=False):
        blockdict = {key[::-1].hex(): value[::-1].hex() for key, value in blockdict.items()}
        if channels:
            update = expand_query(channel_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        else:
            update = expand_query(claim_id__in=list(blockdict.keys()), censor_type=f"<{censor_type}")
        key = 'channel_id' if channels else 'claim_id'
        update['script'] = {
            "source": f"ctx._source.censor_type={censor_type}; ctx._source.censoring_channel_id=params[ctx._source.{key}]",
            "lang": "painless",
            "params": blockdict
        }
        return update

    async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
        if filtered_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if filtered_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.SEARCH, filtered_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_streams:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_streams), slices=4)
            await self.sync_client.indices.refresh(self.index)
        if blocked_channels:
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels), slices=4)
            await self.sync_client.indices.refresh(self.index)
            await self.sync_client.update_by_query(
                self.index, body=self.update_filter_query(Censor.RESOLVE, blocked_channels, True), slices=4)
            await self.sync_client.indices.refresh(self.index)
        self.clear_caches()

    def clear_caches(self):
        self.search_cache.clear()
        self.short_id_cache.clear()
        self.claim_cache.clear()
        self.resolution_cache.clear()

    async def session_query(self, query_name, kwargs):
        offset, total = kwargs.get('offset', 0) if isinstance(kwargs, dict) else 0, 0
        total_referenced = []
        if query_name == 'resolve':
            total_referenced, response, censor = await self.resolve(*kwargs)
        else:
            cache_item = ResultCacheItem.from_cache(str(kwargs), self.search_cache)
            if cache_item.result is not None:
                return cache_item.result
            async with cache_item.lock:
                if cache_item.result:
                    return cache_item.result
                censor = Censor(Censor.SEARCH)
                if kwargs.get('no_totals'):
                    response, offset, total = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                else:
                    response, offset, total = await self.search(**kwargs)
                censor.apply(response)
                total_referenced.extend(response)
                if censor.censored:
                    response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED)
                    total_referenced.extend(response)
                result = Outputs.to_base64(
                    response, await self._get_referenced_rows(total_referenced), offset, total, censor
                )
                cache_item.result = result
                return result
        return Outputs.to_base64(response, await self._get_referenced_rows(total_referenced), offset, total, censor)

    async def resolve(self, *urls):
        censor = Censor(Censor.RESOLVE)
        results = [await self.resolve_url(url) for url in urls]
        # just heat the cache
        await self.populate_claim_cache(*filter(lambda x: isinstance(x, str), results))
        results = [self._get_from_cache_or_error(url, result) for url, result in zip(urls, results)]

        censored = [
            result if not isinstance(result, dict) or not censor.censor(result)
            else ResolveCensoredError(url, result['censoring_channel_id'])
            for url, result in zip(urls, results)
        ]
        return results, censored, censor

    def _get_from_cache_or_error(self, url: str, resolution: Union[LookupError, StreamResolution, ChannelResolution]):
        cached = self.claim_cache.get(resolution)
        return cached or (resolution if isinstance(resolution, LookupError) else resolution.lookup_error(url))

    async def get_many(self, *claim_ids):
        await self.populate_claim_cache(*claim_ids)
        return filter(None, map(self.claim_cache.get, claim_ids))

    async def populate_claim_cache(self, *claim_ids):
        missing = [claim_id for claim_id in claim_ids if self.claim_cache.get(claim_id) is None]
        if missing:
            results = await self.search_client.mget(
                index=self.index, body={"ids": missing}
            )
            for result in expand_result(filter(lambda doc: doc['found'], results["docs"])):
                self.claim_cache.set(result['claim_id'], result)

    async def full_id_from_short_id(self, name, short_id, channel_id=None):
        key = '#'.join((channel_id or '', name, short_id))
        if key not in self.short_id_cache:
            query = {'name': name, 'claim_id': short_id}
            if channel_id:
                query['channel_id'] = channel_id
                query['order_by'] = ['^channel_join']
                query['signature_valid'] = True
            else:
                query['order_by'] = '^creation_height'
            result, _, _ = await self.search(**query, limit=1)
            if len(result) == 1:
                result = result[0]['claim_id']
                self.short_id_cache[key] = result
        return self.short_id_cache.get(key, None)

    async def search(self, **kwargs):
        if 'channel' in kwargs:
            kwargs['channel_id'] = await self.resolve_url(kwargs.pop('channel'))
            if not kwargs['channel_id'] or not isinstance(kwargs['channel_id'], str):
                return [], 0, 0
        try:
            return await self.search_ahead(**kwargs)
        except NotFoundError:
            return [], 0, 0
        # return expand_result(result['hits']), 0, result.get('total', {}).get('value', 0)

    async def search_ahead(self, **kwargs):
        # 'limit_claims_per_channel' case. Fetch 1000 results, reorder, slice, inflate and return
        per_channel_per_page = kwargs.pop('limit_claims_per_channel', 0) or 0
        remove_duplicates = kwargs.pop('remove_duplicates', False)
        page_size = kwargs.pop('limit', 10)
        offset = kwargs.pop('offset', 0)
        kwargs['limit'] = 1000
        cache_item = ResultCacheItem.from_cache(f"ahead{per_channel_per_page}{kwargs}", self.search_cache)
        if cache_item.result is not None:
            reordered_hits = cache_item.result
        else:
            async with cache_item.lock:
                if cache_item.result:
                    reordered_hits = cache_item.result
                else:
                    query = expand_query(**kwargs)
                    search_hits = deque((await self.search_client.search(
                        query, index=self.index, track_total_hits=False,
                        _source_includes=['_id', 'channel_id', 'reposted_claim_id', 'creation_height']
                    ))['hits']['hits'])
                    if remove_duplicates:
                        search_hits = self.__remove_duplicates(search_hits)
                    if per_channel_per_page > 0:
                        reordered_hits = self.__search_ahead(search_hits, page_size, per_channel_per_page)
                    else:
                        reordered_hits = [(hit['_id'], hit['_source']['channel_id']) for hit in search_hits]
                    cache_item.result = reordered_hits
        result = list(await self.get_many(*(claim_id for claim_id, _ in reordered_hits[offset:(offset + page_size)])))
        return result, 0, len(reordered_hits)

    def __remove_duplicates(self, search_hits: deque) -> deque:
        known_ids = {}  # claim_id -> (creation_height, hit_id), where hit_id is either reposted claim id or original
        dropped = set()
        for hit in search_hits:
            hit_height, hit_id = hit['_source']['creation_height'], hit['_source']['reposted_claim_id'] or hit['_id']
            if hit_id not in known_ids:
                known_ids[hit_id] = (hit_height, hit['_id'])
            else:
                previous_height, previous_id = known_ids[hit_id]
                if hit_height < previous_height:
                    known_ids[hit_id] = (hit_height, hit['_id'])
                    dropped.add(previous_id)
                else:
                    dropped.add(hit['_id'])
        return deque(hit for hit in search_hits if hit['_id'] not in dropped)

    def __search_ahead(self, search_hits: list, page_size: int, per_channel_per_page: int):
        reordered_hits = []
        channel_counters = Counter()
        next_page_hits_maybe_check_later = deque()
        while search_hits or next_page_hits_maybe_check_later:
            if reordered_hits and len(reordered_hits) % page_size == 0:
                channel_counters.clear()
            elif not reordered_hits:
                pass
            else:
                break  # means last page was incomplete and we are left with bad replacements
            for _ in range(len(next_page_hits_maybe_check_later)):
                claim_id, channel_id = next_page_hits_maybe_check_later.popleft()
                if per_channel_per_page > 0 and channel_counters[channel_id] < per_channel_per_page:
                    reordered_hits.append((claim_id, channel_id))
                    channel_counters[channel_id] += 1
                else:
                    next_page_hits_maybe_check_later.append((claim_id, channel_id))
            while search_hits:
                hit = search_hits.popleft()
                hit_id, hit_channel_id = hit['_id'], hit['_source']['channel_id']
                if hit_channel_id is None or per_channel_per_page <= 0:
                    reordered_hits.append((hit_id, hit_channel_id))
                elif channel_counters[hit_channel_id] < per_channel_per_page:
                    reordered_hits.append((hit_id, hit_channel_id))
                    channel_counters[hit_channel_id] += 1
                    if len(reordered_hits) % page_size == 0:
                        break
                else:
                    next_page_hits_maybe_check_later.append((hit_id, hit_channel_id))
        return reordered_hits

    async def resolve_url(self, raw_url):
        if raw_url not in self.resolution_cache:
            self.resolution_cache[raw_url] = await self._resolve_url(raw_url)
        return self.resolution_cache[raw_url]

    async def _resolve_url(self, raw_url):
        try:
            url = URL.parse(raw_url)
        except ValueError as e:
            return e

        stream = LookupError(f'Could not find claim at "{raw_url}".')

        channel_id = await self.resolve_channel_id(url)
        if isinstance(channel_id, LookupError):
            return channel_id
        stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream
        if url.has_stream:
            return StreamResolution(stream)
        else:
            return ChannelResolution(channel_id)

    async def resolve_channel_id(self, url: URL):
        if not url.has_channel:
            return
        if url.channel.is_fullid:
            return url.channel.claim_id
        if url.channel.is_shortid:
            channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id)
            if not channel_id:
                return LookupError(f'Could not find channel in "{url}".')
            return channel_id

        query = url.channel.to_dict()
        if set(query) == {'name'}:
            query['is_controlling'] = True
        else:
            query['order_by'] = ['^creation_height']
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            channel_id = matches[0]['claim_id']
        else:
            return LookupError(f'Could not find channel in "{url}".')
        return channel_id

    async def resolve_stream(self, url: URL, channel_id: str = None):
        if not url.has_stream:
            return None
        if url.has_channel and channel_id is None:
            return None
        query = url.stream.to_dict()
        if url.stream.claim_id is not None:
            if url.stream.is_fullid:
                claim_id = url.stream.claim_id
            else:
                claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id)
            return claim_id

        if channel_id is not None:
            if set(query) == {'name'}:
                # temporarily emulate is_controlling for claims in channel
                query['order_by'] = ['effective_amount', '^height']
            else:
                query['order_by'] = ['^channel_join']
            query['channel_id'] = channel_id
            query['signature_valid'] = True
        elif set(query) == {'name'}:
            query['is_controlling'] = True
        matches, _, _ = await self.search(**query, limit=1)
        if matches:
            return matches[0]['claim_id']

    async def _get_referenced_rows(self, txo_rows: List[dict]):
        txo_rows = [row for row in txo_rows if isinstance(row, dict)]
        referenced_ids = set(filter(None, map(itemgetter('reposted_claim_id'), txo_rows)))
        referenced_ids |= set(filter(None, (row['channel_id'] for row in txo_rows)))
        referenced_ids |= set(filter(None, (row['censoring_channel_id'] for row in txo_rows)))

        referenced_txos = []
        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))
            referenced_ids = set(filter(None, (row['channel_id'] for row in referenced_txos)))

        if referenced_ids:
            referenced_txos.extend(await self.get_many(*referenced_ids))

        return referenced_txos
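
The reordering that __search_ahead performs earlier in this example is easier to see in isolation. The sketch below is a hypothetical standalone reimplementation: it works on plain (claim_id, channel_id) tuples instead of Elasticsearch hits and assumes per_channel_per_page > 0, but it follows the same rule as the method above — each channel gets at most per_channel_per_page slots per page, and overflow hits are deferred to later pages.

from collections import Counter, deque

def interleave(hits, page_size, per_channel_per_page):
    # hits: iterable of (claim_id, channel_id) pairs; channel_id may be None
    pending, deferred = deque(hits), deque()
    ordered, counters = [], Counter()
    while pending or deferred:
        if ordered and len(ordered) % page_size == 0:
            counters.clear()              # a new page starts, reset per-channel counts
        elif ordered:
            break                         # ran out of hits mid-page, as in the method above
        # deferred hits get first chance at the slots of the new page
        for _ in range(len(deferred)):
            claim_id, channel_id = deferred.popleft()
            if channel_id is None or counters[channel_id] < per_channel_per_page:
                ordered.append((claim_id, channel_id))
                counters[channel_id] += 1
            else:
                deferred.append((claim_id, channel_id))
        while pending:
            claim_id, channel_id = pending.popleft()
            if channel_id is None or counters[channel_id] < per_channel_per_page:
                ordered.append((claim_id, channel_id))
                counters[channel_id] += 1
                if len(ordered) % page_size == 0:
                    break
            else:
                deferred.append((claim_id, channel_id))
    return ordered

hits = [('a1', 'chan1'), ('a2', 'chan1'), ('a3', 'chan1'), ('b1', 'chan2'), ('c1', None)]
print(interleave(hits, page_size=2, per_channel_per_page=1))
# -> [('a1', 'chan1'), ('b1', 'chan2'), ('a2', 'chan1'), ('c1', None), ('a3', 'chan1')]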
Exemple #11
0
class SPVServerStatusProtocol(asyncio.DatagramProtocol):
    PROTOCOL_VERSION = 1

    def __init__(self,
                 height: int,
                 tip: bytes,
                 throttle_cache_size: int = 1024,
                 throttle_reqs_per_sec: int = 10):
        super().__init__()
        self.transport: Optional[asyncio.transports.DatagramTransport] = None
        self._height = height
        self._tip = tip
        self._flags = 0
        self._cached_response = None
        self.update_cached_response()
        self._throttle = LRUCache(throttle_cache_size)
        self._should_log = LRUCache(throttle_cache_size)
        self._min_delay = 1 / throttle_reqs_per_sec

    def update_cached_response(self):
        self._cached_response = SPVPong.make(self._height, self._tip,
                                             self._flags,
                                             self.PROTOCOL_VERSION)

    def set_unavailable(self):
        self._flags &= 0b11111110
        self.update_cached_response()

    def set_available(self):
        self._flags |= 0b00000001
        self.update_cached_response()

    def set_height(self, height: int, tip: bytes):
        self._height, self._tip = height, tip
        self.update_cached_response()

    def should_throttle(self, host: str):
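        # allow at most throttle_reqs_per_sec status pings per host; log every 100th throttled request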
        now = perf_counter()
        last_requested = self._throttle.get(host, default=0)
        self._throttle[host] = now
        if now - last_requested < self._min_delay:
            log_cnt = self._should_log.get(host, default=0) + 1
            if log_cnt % 100 == 0:
                log.warning("throttle spv status to %s", host)
            self._should_log[host] = log_cnt
            return True
        return False

    def make_pong(self, host):
        return self._cached_response + bytes(int(b) for b in host.split("."))

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        if self.should_throttle(addr[0]):
            return
        try:
            SPVPing.decode(data)
        except (ValueError, struct.error, AttributeError, TypeError):
            # log.exception("derp")
            return
        self.transport.sendto(self.make_pong(addr[0]), addr)
        # ping_count_metric.inc()

    def connection_made(self, transport) -> None:
        self.transport = transport

    def connection_lost(self, exc: Optional[Exception]) -> None:
        self.transport = None

    def close(self):
        if self.transport:
            self.transport.close()
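
A minimal usage sketch, assuming the SPVServerStatusProtocol above is importable; the UDP port and the all-zero tip are placeholders. asyncio's create_datagram_endpoint wires the protocol to a socket, and set_available() flips the availability flag before serving.

import asyncio

async def serve_spv_status(height: int, tip: bytes, port: int = 50001):
    loop = asyncio.get_running_loop()
    transport, protocol = await loop.create_datagram_endpoint(
        lambda: SPVServerStatusProtocol(height, tip),   # class from the example above
        local_addr=('0.0.0.0', port))
    protocol.set_available()                            # flip the availability flag
    try:
        await asyncio.Future()                          # serve until cancelled
    finally:
        protocol.close()

# asyncio.run(serve_spv_status(height=1_000_000, tip=bytes(32)))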
Exemple #12
0
class Daemon:
    """Handles connections to a daemon at the given URL."""

    WARMING_UP = -28
    id_counter = itertools.count()

    lbrycrd_request_time_metric = Histogram("lbrycrd_request",
                                            "lbrycrd requests count",
                                            namespace=NAMESPACE,
                                            labelnames=("method", ))
    lbrycrd_pending_count_metric = Gauge(
        "lbrycrd_pending_count",
        "Number of lbrycrd rpcs that are in flight",
        namespace=NAMESPACE,
        labelnames=("method", ))

    def __init__(self,
                 coin,
                 url,
                 max_workqueue=10,
                 init_retry=0.25,
                 max_retry=4.0):
        self.coin = coin
        self.logger = class_logger(__name__, self.__class__.__name__)
        self.set_url(url)
        # Limit concurrent RPC calls to this number.
        # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
        self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
        self.init_retry = init_retry
        self.max_retry = max_retry
        self._height = None
        self.available_rpcs = {}
        self.connector = aiohttp.TCPConnector()
        self._block_hash_cache = LRUCache(100000)
        self._block_cache = LRUCache(2**16,
                                     metric_name='block',
                                     namespace=NAMESPACE)

    async def close(self):
        if self.connector:
            await self.connector.close()
            self.connector = None

    def set_url(self, url):
        """Set the URLS to the given list, and switch to the first one."""
        urls = url.split(',')
        urls = [self.coin.sanitize_url(url) for url in urls]
        for n, url in enumerate(urls):
            status = '' if n else ' (current)'
            logged_url = self.logged_url(url)
            self.logger.info(f'daemon #{n + 1} at {logged_url}{status}')
        self.url_index = 0
        self.urls = urls

    def current_url(self):
        """Returns the current daemon URL."""
        return self.urls[self.url_index]

    def logged_url(self, url=None):
        """The host and port part, for logging."""
        url = url or self.current_url()
        return url[url.rindex('@') + 1:]

    def failover(self):
        """Call to fail-over to the next daemon URL.

        Returns False if there is only one, otherwise True.
        """
        if len(self.urls) > 1:
            self.url_index = (self.url_index + 1) % len(self.urls)
            self.logger.info(f'failing over to {self.logged_url()}')
            return True
        return False

    def client_session(self):
        """An aiohttp client session."""
        return aiohttp.ClientSession(connector=self.connector,
                                     connector_owner=False)

    async def _send_data(self, data):
        if not self.connector:
            raise asyncio.CancelledError(
                'Tried to send request during shutdown.')
        async with self.workqueue_semaphore:
            async with self.client_session() as session:
                async with session.post(self.current_url(), data=data) as resp:
                    kind = resp.headers.get('Content-Type', None)
                    if kind == 'application/json':
                        return await resp.json()
                    # bitcoind's HTTP protocol "handling" is a bad joke
                    text = await resp.text()
                    if 'Work queue depth exceeded' in text:
                        raise WorkQueueFullError
                    text = text.strip() or resp.reason
                    self.logger.error(text)
                    raise DaemonError(text)

    async def _send(self, payload, processor):
        """Send a payload to be converted to JSON.

        Handles temporary connection issues.  Daemon response errors
        are raised through DaemonError.
        """
        def log_error(error):
            nonlocal last_error_log, retry
            now = time.time()
            if now - last_error_log > 60:
                last_error_log = now
                self.logger.error(f'{error}  Retrying occasionally...')
            if retry == self.max_retry and self.failover():
                retry = 0

        on_good_message = None
        last_error_log = 0
        data = json.dumps(payload)
        retry = self.init_retry
        methods = tuple([payload['method']] if isinstance(payload, dict) else
                        [request['method'] for request in payload])
        while True:
            try:
                for method in methods:
                    self.lbrycrd_pending_count_metric.labels(
                        method=method).inc()
                result = await self._send_data(data)
                result = processor(result)
                if on_good_message:
                    self.logger.info(on_good_message)
                return result
            except asyncio.TimeoutError:
                log_error('timeout error.')
            except aiohttp.ServerDisconnectedError:
                log_error('disconnected.')
                on_good_message = 'connection restored'
            except aiohttp.ClientConnectionError:
                log_error('connection problem - is your daemon running?')
                on_good_message = 'connection restored'
            except aiohttp.ClientError as e:
                log_error(f'daemon error: {e}')
                on_good_message = 'running normally'
            except WarmingUpError:
                log_error('starting up checking blocks.')
                on_good_message = 'running normally'
            except WorkQueueFullError:
                log_error('work queue full.')
                on_good_message = 'running normally'
            finally:
                for method in methods:
                    self.lbrycrd_pending_count_metric.labels(
                        method=method).dec()
            await asyncio.sleep(retry)
            retry = max(min(self.max_retry, retry * 2), self.init_retry)

    async def _send_single(self, method, params=None):
        """Send a single request to the daemon."""

        start = time.perf_counter()

        def processor(result):
            err = result['error']
            if not err:
                return result['result']
            if err.get('code') == self.WARMING_UP:
                raise WarmingUpError
            raise DaemonError(err)

        payload = {'method': method, 'id': next(self.id_counter)}
        if params:
            payload['params'] = params
        result = await self._send(payload, processor)
        self.lbrycrd_request_time_metric.labels(
            method=method).observe(time.perf_counter() - start)
        return result

    async def _send_vector(self, method, params_iterable, replace_errs=False):
        """Send several requests of the same method.

        The result will be an array of the same length as params_iterable.
        If replace_errs is true, any item with an error is returned as None,
        otherwise an exception is raised."""

        start = time.perf_counter()

        def processor(result):
            errs = [item['error'] for item in result if item['error']]
            if any(err.get('code') == self.WARMING_UP for err in errs):
                raise WarmingUpError
            if not errs or replace_errs:
                return [item['result'] for item in result]
            raise DaemonError(errs)

        payload = [{
            'method': method,
            'params': p,
            'id': next(self.id_counter)
        } for p in params_iterable]
        result = []
        if payload:
            result = await self._send(payload, processor)
        self.lbrycrd_request_time_metric.labels(
            method=method).observe(time.perf_counter() - start)
        return result

    async def _is_rpc_available(self, method):
        """Return whether given RPC method is available in the daemon.

        Results are cached and the daemon will generally not be queried with
        the same method more than once."""
        available = self.available_rpcs.get(method)
        if available is None:
            available = True
            try:
                await self._send_single(method)
            except DaemonError as e:
                err = e.args[0]
                error_code = err.get("code")
                available = error_code != JSONRPC.METHOD_NOT_FOUND
            self.available_rpcs[method] = available
        return available

    async def block_hex_hashes(self, first, count):
        """Return the hex hashes of count block starting at height first."""
        if first + count < (self.cached_height() or 0) - 200:
            return await self._cached_block_hex_hashes(first, count)
        params_iterable = ((h, ) for h in range(first, first + count))
        return await self._send_vector('getblockhash', params_iterable)

    async def _cached_block_hex_hashes(self, first, count):
        """Return the hex hashes of count block starting at height first."""
        cached = self._block_hash_cache.get((first, count))
        if cached:
            return cached
        params_iterable = ((h, ) for h in range(first, first + count))
        self._block_hash_cache[(first, count)] = await self._send_vector(
            'getblockhash', params_iterable)
        return self._block_hash_cache[(first, count)]

    async def deserialised_block(self, hex_hash):
        """Return the deserialised block with the given hex hash."""
        if hex_hash not in self._block_cache:
            block = await self._send_single('getblock', (hex_hash, True))
            self._block_cache[hex_hash] = block
            return block
        return self._block_cache[hex_hash]

    async def raw_blocks(self, hex_hashes):
        """Return the raw binary blocks with the given hex hashes."""
        params_iterable = ((h, False) for h in hex_hashes)
        blocks = await self._send_vector('getblock', params_iterable)
        # Convert hex string to bytes
        return [hex_to_bytes(block) for block in blocks]

    async def mempool_hashes(self):
        """Update our record of the daemon's mempool hashes."""
        return await self._send_single('getrawmempool')

    async def estimatefee(self, block_count):
        """Return the fee estimate for the block count.  Units are whole
        currency units per KB, e.g. 0.00000995, or -1 if no estimate
        is available.
        """
        args = (block_count, )
        if await self._is_rpc_available('estimatesmartfee'):
            estimate = await self._send_single('estimatesmartfee', args)
            return estimate.get('feerate', -1)
        return await self._send_single('estimatefee', args)

    async def getnetworkinfo(self):
        """Return the result of the 'getnetworkinfo' RPC call."""
        return await self._send_single('getnetworkinfo')

    async def relayfee(self):
        """The minimum fee a low-priority tx must pay in order to be accepted
        to the daemon's memory pool."""
        network_info = await self.getnetworkinfo()
        return network_info['relayfee']

    async def getrawtransaction(self, hex_hash, verbose=False):
        """Return the serialized raw transaction with the given hash."""
        # Cast to int because some coin daemons are old and require it
        return await self._send_single('getrawtransaction',
                                       (hex_hash, int(verbose)))

    async def getrawtransactions(self, hex_hashes, replace_errs=True):
        """Return the serialized raw transactions with the given hashes.

        Replaces errors with None by default."""
        params_iterable = ((hex_hash, 0) for hex_hash in hex_hashes)
        txs = await self._send_vector('getrawtransaction',
                                      params_iterable,
                                      replace_errs=replace_errs)
        # Convert hex strings to bytes
        return [hex_to_bytes(tx) if tx else None for tx in txs]

    async def broadcast_transaction(self, raw_tx):
        """Broadcast a transaction to the network."""
        return await self._send_single('sendrawtransaction', (raw_tx, ))

    async def height(self):
        """Query the daemon for its current height."""
        self._height = await self._send_single('getblockcount')
        return self._height

    def cached_height(self):
        """Return the cached daemon height.

        If the daemon has not been queried yet this returns None."""
        return self._height
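
_send retries failed RPCs with an exponential backoff: the sleep starts at init_retry, doubles after every failure and is clamped to max_retry, and once the delay has reached max_retry the daemon also tries to fail over to the next URL. A small sketch of the resulting schedule (not part of the example, just the same arithmetic):

def backoff_delays(init_retry=0.25, max_retry=4.0, attempts=8):
    # same arithmetic as the end of Daemon._send: sleep `retry` seconds after a
    # failure, then double it, clamped to the [init_retry, max_retry] range
    retry, delays = init_retry, []
    for _ in range(attempts):
        delays.append(retry)
        retry = max(min(max_retry, retry * 2), init_retry)
    return delays

print(backoff_delays())  # [0.25, 0.5, 1.0, 2.0, 4.0, 4.0, 4.0, 4.0]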
Exemple #13
0
class SPVServerStatusProtocol(asyncio.DatagramProtocol):

    def __init__(
        self, height: int, tip: bytes, country: str,
        throttle_cache_size: int = 1024, throttle_reqs_per_sec: int = 10,
        allow_localhost: bool = False, allow_lan: bool = False
    ):
        super().__init__()
        self.transport: Optional[asyncio.transports.DatagramTransport] = None
        self._height = height
        self._tip = tip
        self._flags = 0
        self._country = country
        self._left_cache = self._right_cache = None
        self.update_cached_response()
        self._throttle = LRUCache(throttle_cache_size)
        self._should_log = LRUCache(throttle_cache_size)
        self._min_delay = 1 / throttle_reqs_per_sec
        self._allow_localhost = allow_localhost
        self._allow_lan = allow_lan

    def update_cached_response(self):
        self._left_cache, self._right_cache = SPVPong.make_sans_source_address(
            self._flags, max(0, self._height), self._tip, self._country
        )

    def set_unavailable(self):
        self._flags &= 0b11111110
        self.update_cached_response()

    def set_available(self):
        self._flags |= 0b00000001
        self.update_cached_response()

    def set_height(self, height: int, tip: bytes):
        self._height, self._tip = height, tip
        self.update_cached_response()

    def should_throttle(self, host: str):
        now = perf_counter()
        last_requested = self._throttle.get(host, default=0)
        self._throttle[host] = now
        if now - last_requested < self._min_delay:
            log_cnt = self._should_log.get(host, default=0) + 1
            if log_cnt % 100 == 0:
                log.warning("throttle spv status to %s", host)
            self._should_log[host] = log_cnt
            return True
        return False

    def make_pong(self, host):
        return self._left_cache + SPVPong.encode_address(host) + self._right_cache

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        if self.should_throttle(addr[0]):
            return
        try:
            SPVPing.decode(data)
        except (ValueError, struct.error, AttributeError, TypeError):
            # log.exception("derp")
            return
        if addr[1] >= 1024 and is_valid_public_ipv4(
                addr[0], allow_localhost=self._allow_localhost, allow_lan=self._allow_lan):
            self.transport.sendto(self.make_pong(addr[0]), addr)
        else:
            log.warning("odd packet from %s:%i", addr[0], addr[1])
        # ping_count_metric.inc()

    def connection_made(self, transport) -> None:
        self.transport = transport

    def connection_lost(self, exc: Optional[Exception]) -> None:
        self.transport = None

    def close(self):
        if self.transport:
            self.transport.close()
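
SPVPong.encode_address is not shown in these examples; judging from the inline version in Exemple #11, it presumably packs the requester's IPv4 address into four raw bytes, which is what pre-rendering the pong in two halves relies on. A hypothetical equivalent using only the standard library:

import socket

def encode_ipv4(host: str) -> bytes:
    # hypothetical stand-in for SPVPong.encode_address: the four raw bytes of an IPv4 address
    return socket.inet_aton(host)

assert encode_ipv4('203.0.113.7') == bytes(int(b) for b in '203.0.113.7'.split('.'))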