# Example No. 1
# 0
def get_pairings(cache: LRUCache, pks: List[G1Element], msgs: List[bytes],
                 force_cache: bool) -> List[GTElement]:
    """
    Return one pairing per (pk, msg) pair, using ``cache`` to skip pairings
    that were computed before.

    For each pair the augmented message is ``bytes(pk) + msg``. The cache key
    is ``bytes(std_hash(aug_msg))``. A missing pairing is computed as
    ``pk.pair(AugSchemeMPL.g2_from_message(aug_msg))`` and stored in ``cache``.

    If ``force_cache`` is False and more than half of the pairings (relative
    to ``len(pks)``) are missing from the cache, an empty list is returned
    and nothing is computed or cached.

    Pairs are formed with ``zip``, so trailing elements of the longer of
    ``pks`` / ``msgs`` are ignored.
    """
    # One entry per pair: (cached pairing or None, cache key, augmented message).
    # Keeping the key and message avoids rebuilding/re-hashing them for misses.
    lookups = []
    missing_count: int = 0
    for pk, msg in zip(pks, msgs):
        aug_msg: bytes = bytes(pk) + msg
        h: bytes = bytes(std_hash(aug_msg))
        pairing: Optional[GTElement] = cache.get(h)
        if not force_cache and pairing is None:
            missing_count += 1
            # Heuristic to avoid more expensive sig validation with pairing
            # cache when it's empty and cached pairings won't be useful later
            # (e.g. while syncing)
            if missing_count > len(pks) // 2:
                return []
        lookups.append((pairing, h, aug_msg))

    # Fill in the misses (in input order) and build a list that is fully
    # typed as List[GTElement].
    result: List[GTElement] = []
    for pk, (pairing, h, aug_msg) in zip(pks, lookups):
        if pairing is None:
            aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
            pairing = pk.pair(aug_hash)
            cache.put(h, pairing)
        result.append(pairing)

    return result
# Example No. 2
# 0
class BlockStore:
    """
    SQLite-backed store for full blocks, block records and sub-epoch
    challenge segments.

    An in-memory LRU cache (``block_cache``, 1000 entries) holds full blocks
    keyed by header hash. It is filled by ``add_full_block`` and by
    ``get_full_block`` on a database hit.
    """

    db: aiosqlite.Connection
    block_cache: LRUCache

    @classmethod
    async def create(cls, connection: aiosqlite.Connection):
        """
        Build a BlockStore on ``connection``.

        Creates the tables and indexes the store needs if they do not exist,
        drops the obsolete ``sub_epoch_segments`` table, and commits.
        """
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db = connection
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
            "  is_block tinyint, is_fully_compactified tinyint, block blob)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint,"
            "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
        )

        # todo remove in v1.2
        await self.db.execute("DROP TABLE IF EXISTS sub_epoch_segments")

        # Sub epoch segments for weight proofs
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v2(ses_height bigint PRIMARY KEY, challenge_segments blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)"
        )
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on full_blocks(is_block)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        # NOTE(review): SQLite index names are unique per schema, and an index
        # named "is_block" was already created on full_blocks above. This
        # statement is therefore a no-op, and block_records(is_block) is not
        # indexed. Nothing in this class filters block_records by is_block.
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on block_records(is_block)")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        return self

    async def begin_transaction(self) -> None:
        """Start an explicit SQLite transaction on the shared connection."""
        # Also locks the coin store, since both stores must be updated at once
        cursor = await self.db.execute("BEGIN TRANSACTION")
        await cursor.close()

    async def commit_transaction(self) -> None:
        """Commit the current transaction on the shared connection."""
        await self.db.commit()

    async def rollback_transaction(self) -> None:
        """Roll back the current transaction on the shared connection."""
        # Also rolls back the coin store, since both stores must be updated at once
        cursor = await self.db.execute("ROLLBACK")
        await cursor.close()

    async def add_full_block(self, block: FullBlock,
                             block_record: BlockRecord) -> None:
        """
        Cache ``block`` and insert or replace its rows in ``full_blocks`` and
        ``block_records``, then commit.

        The new block record is written with ``is_peak`` false; use
        ``set_peak`` to mark a peak.
        """
        self.block_cache.put(block.header_hash, block)
        cursor_1 = await self.db.execute(
            "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
            (
                block.header_hash.hex(),
                block.height,
                int(block.is_transaction_block()),
                int(block.is_fully_compactified()),
                bytes(block),
            ),
        )

        await cursor_1.close()

        cursor_2 = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
            (
                block.header_hash.hex(),
                block.prev_header_hash.hex(),
                block.height,
                bytes(block_record),
                None if block_record.sub_epoch_summary_included is None else
                bytes(block_record.sub_epoch_summary_included),
                False,
                block.is_transaction_block(),
            ),
        )
        await cursor_2.close()
        # NOTE(review): this commits whatever transaction is open on the
        # connection, including one started by begin_transaction.
        await self.db.commit()

    async def persist_sub_epoch_challenge_segments(
            self, sub_epoch_summary_height: uint32,
            segments: List[SubEpochChallengeSegment]) -> None:
        """
        Insert or replace the challenge segments for a sub-epoch summary height.

        Does not commit.
        """
        cursor_1 = await self.db.execute(
            "INSERT OR REPLACE INTO sub_epoch_segments_v2 VALUES(?, ?)",
            (sub_epoch_summary_height, bytes(SubEpochSegments(segments))),
        )
        await cursor_1.close()

    async def get_sub_epoch_challenge_segments(
        self,
        sub_epoch_summary_height: uint32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        """Return the stored challenge segments for a height, or None if absent."""
        cursor = await self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v2 WHERE ses_height=?",
            (sub_epoch_summary_height, ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return SubEpochSegments.from_bytes(row[0]).challenge_segments
        return None

    async def delete_sub_epoch_challenge_segments(self,
                                                  fork_height: uint32) -> None:
        """Delete all challenge segments stored above ``fork_height``. Does not commit."""
        cursor = await self.db.execute(
            "delete from sub_epoch_segments_v2 WHERE ses_height>?",
            (fork_height, ))
        await cursor.close()

    async def get_full_block(self,
                             header_hash: bytes32) -> Optional[FullBlock]:
        """
        Return the full block with ``header_hash``, or None if not stored.

        Checks the LRU cache first. A block loaded from the database is added
        to the cache.
        """
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            return cached
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        block = FullBlock.from_bytes(row[0])
        self.block_cache.put(header_hash, block)
        return block

    async def get_full_blocks_at(self,
                                 heights: List[uint32]) -> List[FullBlock]:
        """
        Return all stored full blocks whose height is in ``heights``.

        The query has no ORDER BY, so the result order is not guaranteed.
        """
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from full_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [FullBlock.from_bytes(row[0]) for row in rows]

    async def get_block_records_at(self,
                                   heights: List[uint32]) -> List[BlockRecord]:
        """Return the block records whose height is in ``heights``, ordered by ascending height."""
        if len(heights) == 0:
            return []
        heights_db = tuple(heights)
        formatted_str = (
            f'SELECT block from block_records WHERE height in ({"?," * (len(heights_db) - 1)}?) ORDER BY height ASC;'
        )
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [BlockRecord.from_bytes(row[0]) for row in rows]

    async def get_blocks_by_hash(
            self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Returns a list of Full Blocks blocks, ordered by the same order in which header_hashes are passed in.
        Throws an exception if the blocks are not present
        """

        if len(header_hashes) == 0:
            return []

        header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = f'SELECT block from full_blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, header_hashes_db)
        rows = await cursor.fetchall()
        await cursor.close()
        all_blocks: Dict[bytes32, FullBlock] = {}
        for row in rows:
            full_block: FullBlock = FullBlock.from_bytes(row[0])
            all_blocks[full_block.header_hash] = full_block
        ret: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_header_blocks_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, HeaderBlock]:
        """Return header blocks for every stored full block with ``start <= height <= stop``."""
        cursor = await self.db.execute(
            "SELECT header_hash,block from full_blocks WHERE height >= ? and height <= ?",
            (start, stop),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, HeaderBlock] = {}
        for row in rows:
            # Ugly hack, until full_block.get_block_header is rewritten as part of generator runner change
            await asyncio.sleep(0.001)
            header_hash = bytes.fromhex(row[0])
            full_block: FullBlock = FullBlock.from_bytes(row[1])
            ret[header_hash] = full_block.get_block_header()

        return ret

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        """Return the block record for ``header_hash``, or None if not stored."""
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records(
        self, ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """
        cursor = await self.db.execute("SELECT * from block_records")
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        peak: Optional[bytes32] = None
        for row in rows:
            # Column order of block_records: header_hash, prev_hash, height,
            # block, sub_epoch_summary, is_peak, is_block.
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[3])
            if row[5]:
                assert peak is None  # Sanity check, only one peak
                peak = header_hash
        return ret, peak

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks in range between start and stop
        (both inclusive), if present.
        """
        cursor = await self.db.execute(
            "SELECT header_hash, block from block_records WHERE height >= ? and height <= ?",
            (start, stop),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
        peak header hash. Returns ({}, None) when no peak is set.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, None

        # peak_row[2] is the peak's height column.
        cursor = await self.db.execute(
            "SELECT header_hash, block  from block_records WHERE height >= ?",
            (peak_row[2] - blocks_n, ),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])
        return ret, bytes.fromhex(peak_row[0])

    async def get_peak_height_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Walk the chain from the current peak back to height 0.

        Returns two dictionaries for the blocks on that chain: height -> header
        hash, and height -> sub-epoch summary (only for blocks that store
        one). Returns ({}, {}) when no peak is set.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(peak_row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for row in rows:
            header_hash = bytes.fromhex(row[0])
            hash_to_prev_hash[header_hash] = bytes.fromhex(row[1])
            hash_to_height[header_hash] = row[2]
            if row[3] is not None:
                hash_to_summary[header_hash] = SubEpochSummary.from_bytes(row[3])

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries

    async def set_peak(self, header_hash: bytes32) -> None:
        """Clear the current peak flag and mark ``header_hash`` as the peak. Does not commit."""
        # We need to be in a sqlite transaction here.
        # Note: we do not commit this to the database yet, as we need to also change the coin store
        cursor_1 = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await cursor_1.close()
        cursor_2 = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await cursor_2.close()

    async def is_fully_compactified(self,
                                    header_hash: bytes32) -> Optional[bool]:
        """Return the stored compactification flag for a block, or None if the block is unknown."""
        cursor = await self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return bool(row[0])

    async def get_first_not_compactified(self,
                                         min_height: int) -> Optional[int]:
        """
        Return the lowest height >= ``min_height`` of a block that is not fully
        compactified, or None if there is no such block.
        """
        cursor = await self.db.execute(
            "SELECT MIN(height) from full_blocks WHERE is_fully_compactified=0 AND height>=?",
            (min_height, ))
        row = await cursor.fetchone()
        await cursor.close()
        # An aggregate like MIN() always yields one row. When no rows match,
        # that row holds NULL, which would make int() raise TypeError.
        if row is None or row[0] is None:
            return None
        return int(row[0])
# Example No. 3
# 0
class CoinStore:
    """
    This object handles CoinRecords in DB.
    A cache is maintained for quicker access to recent coins.
    """

    coin_record_db: aiosqlite.Connection
    coin_record_cache: LRUCache
    cache_size: uint32
    db_wrapper: DBWrapper

    @classmethod
    async def create(cls, db_wrapper: DBWrapper, cache_size: uint32 = uint32(60000)):
        """
        Build a CoinStore on ``db_wrapper.db``.

        Sets the WAL journal mode and ``synchronous=2`` pragmas, creates the
        ``coin_record`` table and its indexes if missing, commits, and creates
        an LRU cache holding up to ``cache_size`` coin records.
        """
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db
        await self.coin_record_db.execute("pragma journal_mode=wal")
        await self.coin_record_db.execute("pragma synchronous=2")
        await self.coin_record_db.execute(
            (
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name text PRIMARY KEY,"
                " confirmed_index bigint,"
                " spent_index bigint,"
                " spent int,"
                " coinbase int,"
                " puzzle_hash text,"
                " coin_parent text,"
                " amount blob,"
                " timestamp bigint)"
            )
        )

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)")

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent on coin_record(spent)")

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)")

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self

    @staticmethod
    def _row_to_coin(row) -> Coin:
        """Rebuild the Coin from a ``SELECT *`` row of coin_record."""
        # Columns: 0 coin_name, 1 confirmed_index, 2 spent_index, 3 spent,
        # 4 coinbase, 5 puzzle_hash, 6 coin_parent, 7 amount, 8 timestamp.
        return Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))

    @classmethod
    def _row_to_coin_record(cls, row) -> CoinRecord:
        """Rebuild the CoinRecord from a ``SELECT *`` row of coin_record."""
        return CoinRecord(cls._row_to_coin(row), row[1], row[2], row[3], row[4], row[8])

    async def new_block(self, block: FullBlock, tx_additions: List[Coin], tx_removals: List[bytes32]):
        """
        Record the coins created and spent by a transaction block.

        Adds ``tx_additions`` and the block's included reward coins (the latter
        flagged as coinbase) as unspent records, then marks every coin in
        ``tx_removals`` as spent at ``block.height``. Non-transaction blocks
        are ignored.
        """
        if not block.is_transaction_block():
            return None
        assert block.foliage_transaction_block is not None
        timestamp = block.foliage_transaction_block.timestamp

        for coin in tx_additions:
            record: CoinRecord = CoinRecord(
                coin,
                block.height,
                uint32(0),
                False,
                False,
                timestamp,
            )
            await self._add_coin_record(record, False)

        included_reward_coins = block.get_included_reward_coins()
        if block.height == 0:
            assert len(included_reward_coins) == 0
        else:
            assert len(included_reward_coins) >= 2

        for coin in included_reward_coins:
            reward_coin_r: CoinRecord = CoinRecord(
                coin,
                block.height,
                uint32(0),
                False,
                True,
                timestamp,
            )
            await self._add_coin_record(reward_coin_r, False)

        total_amount_spent: int = 0
        for coin_name in tx_removals:
            total_amount_spent += await self._set_spent(coin_name, block.height)

        # Sanity check, already checked in block_body_validation
        assert sum([a.amount for a in tx_additions]) <= total_amount_spent

    # Checks DB and DiffStores for CoinRecord with coin_name and returns it
    async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]:
        """Return the record for ``coin_name`` from the cache or DB (caching a DB hit), or None."""
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        record = self._row_to_coin_record(row)
        self.coin_record_cache.put(record.coin.name(), record)
        return record

    async def get_coins_added_at_height(self, height: uint32) -> List[CoinRecord]:
        """Return all coin records confirmed at ``height``."""
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE confirmed_index=?", (height,))
        rows = await cursor.fetchall()
        await cursor.close()
        return [self._row_to_coin_record(row) for row in rows]

    async def get_coins_removed_at_height(self, height: uint32) -> List[CoinRecord]:
        """Return all coin records spent at ``height``."""
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE spent_index=?", (height,))
        rows = await cursor.fetchall()
        await cursor.close()
        coins = []
        for row in rows:
            # Unspent coins are stored with spent_index 0, so a query for
            # height 0 also matches them; keep only coins flagged as spent.
            spent: bool = bool(row[3])
            if spent:
                coin_record = CoinRecord(self._row_to_coin(row), row[1], row[2], spent, row[4], row[8])
                coins.append(coin_record)
        return coins

    # Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
    async def get_coin_records_by_puzzle_hash(
        self,
        include_spent_coins: bool,
        puzzle_hash: bytes32,
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        """
        Return the records with ``puzzle_hash`` confirmed in
        ``[start_height, end_height)``, optionally excluding spent coins.
        """
        coins = set()
        cursor = await self.coin_record_db.execute(
            f"SELECT * from coin_record INDEXED BY coin_puzzle_hash WHERE puzzle_hash=? "
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            (puzzle_hash.hex(), start_height, end_height),
        )
        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coins.add(self._row_to_coin_record(row))
        return list(coins)

    async def get_coin_records_by_puzzle_hashes(
        self,
        include_spent_coins: bool,
        puzzle_hashes: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        """
        Return the records whose puzzle hash is in ``puzzle_hashes`` and that were
        confirmed in ``[start_height, end_height)``, optionally excluding spent coins.
        """
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        cursor = await self.coin_record_db.execute(
            f"SELECT * from coin_record INDEXED BY coin_puzzle_hash "
            f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            puzzle_hashes_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coins.add(self._row_to_coin_record(row))
        return list(coins)

    async def get_coin_records_by_names(
        self,
        include_spent_coins: bool,
        names: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        """
        Return the records whose coin name is in ``names`` and that were
        confirmed in ``[start_height, end_height)``, optionally excluding spent coins.
        """
        if len(names) == 0:
            return []

        coins = set()
        names_db = tuple([name.hex() for name in names])
        cursor = await self.coin_record_db.execute(
            f'SELECT * from coin_record WHERE coin_name in ({"?," * (len(names_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            names_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coins.add(self._row_to_coin_record(row))
        return list(coins)

    async def get_coin_records_by_parent_ids(
        self,
        include_spent_coins: bool,
        parent_ids: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        """
        Return the records whose parent coin id is in ``parent_ids`` and that were
        confirmed in ``[start_height, end_height)``, optionally excluding spent coins.
        """
        if len(parent_ids) == 0:
            return []

        coins = set()
        parent_ids_db = tuple([pid.hex() for pid in parent_ids])
        cursor = await self.coin_record_db.execute(
            f'SELECT * from coin_record WHERE coin_parent in ({"?," * (len(parent_ids_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            parent_ids_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coins.add(self._row_to_coin_record(row))
        return list(coins)

    async def rollback_to_block(self, block_index: int):
        """
        Undo every coin change above ``block_index``, in the cache and in the DB.

        Coins confirmed above ``block_index`` are deleted. Coins spent above it
        are marked unspent again (spent index 0). The DB changes are not committed.

        Note that block_index can be negative, in which case everything is rolled back
        """
        # Update memory cache. We iterate over a snapshot, so the cache can be
        # modified inside the loop.
        for coin_name, coin_record in list(self.coin_record_cache.cache.items()):
            if int(coin_record.confirmed_block_index) > block_index:
                # Created after the rollback point: the coin no longer exists.
                self.coin_record_cache.remove(coin_name)
            elif int(coin_record.spent_block_index) > block_index:
                # Spent after the rollback point: it becomes unspent again.
                new_record = CoinRecord(
                    coin_record.coin,
                    coin_record.confirmed_block_index,
                    uint32(0),
                    False,
                    coin_record.coinbase,
                    coin_record.timestamp,
                )
                self.coin_record_cache.put(coin_name, new_record)

        # Delete from storage
        c1 = await self.coin_record_db.execute("DELETE FROM coin_record WHERE confirmed_index>?", (block_index,))
        await c1.close()
        c2 = await self.coin_record_db.execute(
            "UPDATE coin_record SET spent_index = 0, spent = 0 WHERE spent_index>?",
            (block_index,),
        )
        await c2.close()

    # Store CoinRecord in DB and ram cache
    async def _add_coin_record(self, record: CoinRecord, allow_replace: bool) -> None:
        """
        Write ``record`` to the DB and drop any cached copy of it.

        With ``allow_replace`` False, an existing row with the same coin name
        makes the INSERT fail.
        """
        coin_name = record.coin.name()
        # Only remove if present: the cache's remove() is not called for
        # missing keys here.
        if self.coin_record_cache.get(coin_name) is not None:
            self.coin_record_cache.remove(coin_name)

        cursor = await self.coin_record_db.execute(
            f"INSERT {'OR REPLACE ' if allow_replace else ''}INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                coin_name.hex(),
                record.confirmed_block_index,
                record.spent_block_index,
                int(record.spent),
                int(record.coinbase),
                record.coin.puzzle_hash.hex(),
                record.coin.parent_coin_info.hex(),
                bytes(record.coin.amount),
                record.timestamp,
            ),
        )
        await cursor.close()

    # Update coin_record to be spent in DB
    async def _set_spent(self, coin_name: bytes32, index: uint32) -> uint64:
        """
        Mark ``coin_name`` as spent at ``index`` and return the coin's amount.

        Raises ValueError if the coin is unknown.
        """
        current: Optional[CoinRecord] = await self.get_coin_record(coin_name)
        if current is None:
            raise ValueError(f"Cannot spend a coin that does not exist in db: {coin_name}")

        assert not current.spent  # Redundant sanity check, already checked in block_body_validation
        spent: CoinRecord = CoinRecord(
            current.coin,
            current.confirmed_block_index,
            index,
            True,
            current.coinbase,
            current.timestamp,
        )  # type: ignore # noqa
        await self._add_coin_record(spent, True)
        return current.coin.amount
# Example No. 4
# 0
class CoinStore:
    """
    This object handles CoinRecords in DB.
    A cache is maintained for quicker access to recent coins.

    Two on-disk layouts are supported, selected by ``db_wrapper.db_version``:
    with version 2 hashes are stored as raw blobs and a coin counts as
    unspent while its ``spent_index`` is 0; otherwise hashes are stored as
    hex text and an explicit ``spent`` column is kept next to ``spent_index``.
    """

    coin_record_db: aiosqlite.Connection
    coin_record_cache: LRUCache
    cache_size: uint32
    db_wrapper: DBWrapper

    @classmethod
    async def create(cls,
                     db_wrapper: DBWrapper,
                     cache_size: uint32 = uint32(60000)):
        """
        Create the coin_record table and its indexes (if missing) and return a
        ready-to-use CoinStore.

        :param db_wrapper: supplies the connection (``db``), the schema
            version (``db_version``) and the ``allow_upgrades`` flag.
        :param cache_size: capacity handed to the in-memory ``LRUCache``.
        """
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db

        if self.db_wrapper.db_version == 2:

            # the coin_name is unique in this table because the CoinStore always
            # only represent a single peak
            await self.coin_record_db.execute(
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name blob PRIMARY KEY,"
                " confirmed_index bigint,"
                " spent_index bigint,"  # if this is zero, it means the coin has not been spent
                " coinbase int,"
                " puzzle_hash blob,"
                " coin_parent blob,"
                " amount blob,"  # we use a blob of 8 bytes to store uint64
                " timestamp bigint)")

        else:

            # the coin_name is unique in this table because the CoinStore always
            # only represent a single peak
            await self.coin_record_db.execute(
                ("CREATE TABLE IF NOT EXISTS coin_record("
                 "coin_name text PRIMARY KEY,"
                 " confirmed_index bigint,"
                 " spent_index bigint,"
                 " spent int,"
                 " coinbase int,"
                 " puzzle_hash text,"
                 " coin_parent text,"
                 " amount blob,"
                 " timestamp bigint)"))

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)"
        )

        if self.db_wrapper.allow_upgrades:
            # The legacy "coin_spent" index is dropped when upgrades are allowed
            await self.coin_record_db.execute("DROP INDEX IF EXISTS coin_spent")

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_parent_index on coin_record(coin_parent)"
        )

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self

    def maybe_from_hex(self, field: Any) -> bytes:
        """Decode a hash column value: raw bytes for db v2, hex text otherwise."""
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return bytes.fromhex(field)

    def maybe_to_hex(self, field: bytes) -> Any:
        """Encode a hash for a query parameter: raw bytes for db v2, hex text otherwise."""
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return field.hex()

    async def new_block(
        self,
        height: uint32,
        timestamp: uint64,
        included_reward_coins: Set[Coin],
        tx_additions: List[Coin],
        tx_removals: List[bytes32],
    ) -> List[CoinRecord]:
        """
        Only called for blocks which are blocks (and thus have rewards and transactions)
        Returns a list of the CoinRecords that were added by this block

        Transaction additions are stored as non-coinbase and reward coins as
        coinbase records, all confirmed at ``height``; every coin id in
        ``tx_removals`` is then marked spent at ``height``. The elapsed time is
        logged at WARNING level when it exceeds 10 seconds, DEBUG otherwise.
        """

        start = time()

        additions = []

        for coin in tx_additions:
            record: CoinRecord = CoinRecord(
                coin,
                height,
                uint32(0),
                False,
                timestamp,
            )
            additions.append(record)

        if height == 0:
            assert len(included_reward_coins) == 0
        else:
            assert len(included_reward_coins) >= 2

        for coin in included_reward_coins:
            reward_coin_r: CoinRecord = CoinRecord(
                coin,
                height,
                uint32(0),
                True,
                timestamp,
            )
            additions.append(reward_coin_r)

        await self._add_coin_records(additions)
        await self._set_spent(tx_removals, height)

        end = time()
        log.log(
            logging.WARNING if end - start > 10 else logging.DEBUG,
            f"It took {end - start:0.2f}s to apply {len(tx_additions)} additions and "
            + f"{len(tx_removals)} removals to the coin store. Make sure " +
            "blockchain database is on a fast drive",
        )

        return additions

    # Checks the ram cache, then the DB, for the CoinRecord with coin_name and returns it
    async def get_coin_record(self,
                              coin_name: bytes32) -> Optional[CoinRecord]:
        """Return the record for ``coin_name`` (caching DB hits), or None if unknown."""
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached

        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE coin_name=?",
            (self.maybe_to_hex(coin_name), ),
        ) as cursor:
            row = await cursor.fetchone()
            if row is not None:
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, row[0], row[1], row[2], row[6])
                self.coin_record_cache.put(record.coin.name(), record)
                return record
        return None

    async def get_coins_added_at_height(self,
                                        height: uint32) -> List[CoinRecord]:
        """Return the records of all coins whose confirmed_index equals ``height``."""
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE confirmed_index=?",
            (height, ),
        ) as cursor:
            rows = await cursor.fetchall()
            coins = []
            for row in rows:
                coin = self.row_to_coin(row)
                coins.append(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return coins

    async def get_coins_removed_at_height(self,
                                          height: uint32) -> List[CoinRecord]:
        """Return the records of all coins whose spent_index equals ``height`` ([] for 0)."""
        # Special case to avoid querying all unspent coins (spent_index=0)
        if height == 0:
            return []
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE spent_index=?",
            (height, ),
        ) as cursor:
            coins = []
            for row in await cursor.fetchall():
                if row[1] != 0:
                    coin = self.row_to_coin(row)
                    coin_record = CoinRecord(coin, row[0], row[1], row[2],
                                             row[6])
                    coins.append(coin_record)
            return coins

    # Checks the DB for CoinRecords with puzzle_hash and returns them
    async def get_coin_records_by_puzzle_hash(
            self,
            include_spent_coins: bool,
            puzzle_hash: bytes32,
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        """
        Return records with ``puzzle_hash`` confirmed in [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """

        coins = set()

        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash WHERE puzzle_hash=? "
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
            (self.maybe_to_hex(puzzle_hash), start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_records_by_puzzle_hashes(
            self,
            include_spent_coins: bool,
            puzzle_hashes: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        """
        Return records whose puzzle hash is in ``puzzle_hashes`` and that were
        confirmed in [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            puzzle_hashes_db = tuple(puzzle_hashes)
        else:
            puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
                f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                puzzle_hashes_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_records_by_names(
            self,
            include_spent_coins: bool,
            names: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        """
        Return records whose coin id is in ``names`` and that were confirmed in
        [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """
        if len(names) == 0:
            return []

        coins = set()
        names_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            names_db = tuple(names)
        else:
            names_db = tuple([name.hex() for name in names])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_name in ({"?," * (len(names) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                names_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))

        return list(coins)

    def row_to_coin(self, row) -> Coin:
        """
        Build a Coin from a row selected as (confirmed_index, spent_index,
        coinbase, puzzle_hash, coin_parent, amount, timestamp).
        """
        return Coin(bytes32(self.maybe_from_hex(row[4])),
                    bytes32(self.maybe_from_hex(row[3])),
                    uint64.from_bytes(row[5]))

    def row_to_coin_state(self, row):
        """Build a CoinState from a row; the spent height is None while spent_index is 0."""
        coin = self.row_to_coin(row)
        spent_h = None
        if row[1] != 0:
            spent_h = row[1]
        return CoinState(coin, spent_h, row[0])

    async def get_coin_states_by_puzzle_hashes(
            self,
            include_spent_coins: bool,
            puzzle_hashes: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinState]:
        """
        Return CoinStates whose puzzle hash is in ``puzzle_hashes`` and that
        were confirmed in [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            puzzle_hashes_db = tuple(puzzle_hashes)
        else:
            puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
                f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                puzzle_hashes_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coins.add(self.row_to_coin_state(row))

            return list(coins)

    async def get_coin_records_by_parent_ids(
            self,
            include_spent_coins: bool,
            parent_ids: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        """
        Return records whose parent coin id is in ``parent_ids`` and that were
        confirmed in [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """
        if len(parent_ids) == 0:
            return []

        coins = set()
        parent_ids_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            parent_ids_db = tuple(parent_ids)
        else:
            parent_ids_db = tuple([pid.hex() for pid in parent_ids])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_parent in ({"?," * (len(parent_ids) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                parent_ids_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_state_by_ids(
            self,
            include_spent_coins: bool,
            coin_ids: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinState]:
        """
        Return CoinStates whose coin id is in ``coin_ids`` and that were
        confirmed in [start_height, end_height).

        Spent coins are excluded unless ``include_spent_coins`` is True.
        """
        if len(coin_ids) == 0:
            return []

        coins = set()
        coin_ids_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            coin_ids_db = tuple(coin_ids)
        else:
            coin_ids_db = tuple([pid.hex() for pid in coin_ids])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_name in ({"?," * (len(coin_ids) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                coin_ids_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coins.add(self.row_to_coin_state(row))
            return list(coins)

    async def rollback_to_block(self, block_index: int) -> List[CoinRecord]:
        """
        Note that block_index can be negative, in which case everything is rolled back
        Returns the list of coin records that have been modified

        Coins confirmed after ``block_index`` are deleted and reported with
        confirmed_index 0 and timestamp 0. Coins confirmed at or before
        ``block_index`` but spent after it are marked unspent again and
        reported with spent_index 0.
        """
        # Update memory cache
        delete_queue: List[bytes32] = []
        for coin_name, coin_record in list(
                self.coin_record_cache.cache.items()):
            if int(coin_record.spent_block_index) > block_index:
                new_record = CoinRecord(
                    coin_record.coin,
                    coin_record.confirmed_block_index,
                    uint32(0),
                    coin_record.coinbase,
                    coin_record.timestamp,
                )
                self.coin_record_cache.put(coin_record.coin.name(), new_record)
            if int(coin_record.confirmed_block_index) > block_index:
                delete_queue.append(coin_name)

        for coin_name in delete_queue:
            self.coin_record_cache.remove(coin_name)

        coin_changes: Dict[bytes32, CoinRecord] = {}
        # Coins created after the fork point; they are deleted below.
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE confirmed_index>?",
            (block_index, ),
        ) as cursor:
            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, uint32(0), row[1], row[2], uint64(0))
                coin_changes[record.name] = record

        # Delete from storage
        await self.coin_record_db.execute(
            "DELETE FROM coin_record WHERE confirmed_index>?", (block_index, ))

        # Coins that survive the delete but were spent after the fork point;
        # they become unspent again. This must select on spent_index: every
        # row with confirmed_index > block_index was just deleted, so
        # selecting on confirmed_index here would always return nothing.
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE spent_index>?",
            (block_index, ),
        ) as cursor:
            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, row[0], uint32(0), row[2], row[6])
                if record.name not in coin_changes:
                    coin_changes[record.name] = record

        if self.db_wrapper.db_version == 2:
            await self.coin_record_db.execute(
                "UPDATE coin_record SET spent_index=0 WHERE spent_index>?",
                (block_index, ))
        else:
            await self.coin_record_db.execute(
                "UPDATE coin_record SET spent_index = 0, spent = 0 WHERE spent_index>?",
                (block_index, ))
        return list(coin_changes.values())

    # Store CoinRecord in DB and ram cache
    async def _add_coin_records(self, records: List[CoinRecord]) -> None:
        """
        Insert ``records`` in one ``executemany`` and put each into the cache.

        Plain ``INSERT`` is used, so an already-present coin_name violates the
        PRIMARY KEY and makes the statement fail.
        """

        if self.db_wrapper.db_version == 2:
            values2 = []
            for record in records:
                self.coin_record_cache.put(record.coin.name(), record)
                values2.append((
                    record.coin.name(),
                    record.confirmed_block_index,
                    record.spent_block_index,
                    int(record.coinbase),
                    record.coin.puzzle_hash,
                    record.coin.parent_coin_info,
                    bytes(record.coin.amount),
                    record.timestamp,
                ))
            await self.coin_record_db.executemany(
                "INSERT INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
                values2,
            )
        else:
            values = []
            for record in records:
                self.coin_record_cache.put(record.coin.name(), record)
                values.append((
                    record.coin.name().hex(),
                    record.confirmed_block_index,
                    record.spent_block_index,
                    int(record.spent),
                    int(record.coinbase),
                    record.coin.puzzle_hash.hex(),
                    record.coin.parent_coin_info.hex(),
                    bytes(record.coin.amount),
                    record.timestamp,
                ))
            await self.coin_record_db.executemany(
                "INSERT INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)",
                values,
            )

    # Update coin_record to be spent in DB
    async def _set_spent(self, coin_names: List[bytes32], index: uint32):
        """
        Mark every coin in ``coin_names`` as spent at block ``index``.

        Cached records are replaced with a copy carrying the new spent index.
        NOTE(review): an UPDATE matching no row affects nothing and does not
        raise, so unknown coin ids are ignored silently here.
        """

        assert len(coin_names) == 0 or index > 0
        # if this coin is in the cache, mark it as spent in there
        updates = []
        for coin_name in coin_names:
            r = self.coin_record_cache.get(coin_name)
            if r is not None:
                self.coin_record_cache.put(
                    r.name,
                    CoinRecord(r.coin, r.confirmed_block_index, index,
                               r.coinbase, r.timestamp))
            updates.append((index, self.maybe_to_hex(coin_name)))

        if self.db_wrapper.db_version == 2:
            await self.coin_record_db.executemany(
                "UPDATE OR FAIL coin_record SET spent_index=? WHERE coin_name=?",
                updates)
        else:
            await self.coin_record_db.executemany(
                "UPDATE OR FAIL coin_record SET spent=1,spent_index=? WHERE coin_name=?",
                updates)
Ejemplo n.º 5
0
class WalletBlockStore:
    """
    This object handles HeaderBlocks and Blocks stored in DB used by wallet.

    Hashes are stored as hex text. HeaderBlockRecords read from the DB are
    kept in an in-memory LRU cache keyed by header hash.
    """

    db: aiosqlite.Connection
    db_wrapper: DBWrapper
    block_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        """
        Configure the connection (WAL journal, synchronous=2), create the
        wallet block tables and indexes if missing, and return a ready store.
        """
        self = cls()

        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute("pragma journal_mode=wal")
        await self.db.execute("pragma synchronous=2")

        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS header_blocks(header_hash text PRIMARY KEY, height int,"
            " timestamp int, block blob)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS header_hash on header_blocks(header_hash)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS timestamp on header_blocks(timestamp)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on header_blocks(height)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint, weight bigint, total_iters text,"
            "block blob, sub_epoch_summary blob, is_peak tinyint)")

        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS additional_coin_spends(header_hash text PRIMARY KEY, spends_list_blob blob)"
        )

        # Height index so we can look up in order of height for sync purposes.
        # SQLite index names share one namespace per database: reusing the name
        # "height" (already taken by the header_blocks index above) would make
        # this IF NOT EXISTS a silent no-op and leave block_records unindexed.
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS block_records_height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.commit()
        self.block_cache = LRUCache(1000)
        return self

    async def _clear_database(self):
        """
        Delete every row of header_blocks and commit.

        NOTE(review): block_records and additional_coin_spends are left
        untouched; confirm this partial clear is intended.
        """
        cursor_2 = await self.db.execute("DELETE FROM header_blocks")
        await cursor_2.close()
        await self.db.commit()

    async def add_block_record(
        self,
        header_block_record: HeaderBlockRecord,
        block_record: BlockRecord,
        additional_coin_spends: List[CoinSolution],
    ):
        """
        Adds a block record to the database. This block record is assumed to be connected
        to the chain, but it may or may not be in the LCA path.

        Writes (or replaces) the block's header_blocks and block_records rows,
        plus an additional_coin_spends row when any spends are given. The
        block_records row is written with is_peak=0; use set_peak to mark it.
        """
        cached = self.block_cache.get(header_block_record.header_hash)
        if cached is not None:
            # Since write to db can fail, we remove from cache here to avoid potential inconsistency
            # Adding to cache only from reading
            self.block_cache.remove(header_block_record.header_hash)

        if header_block_record.header.foliage_transaction_block is not None:
            timestamp = header_block_record.header.foliage_transaction_block.timestamp
        else:
            timestamp = uint64(0)
        cursor = await self.db.execute(
            "INSERT OR REPLACE INTO header_blocks VALUES(?, ?, ?, ?)",
            (
                header_block_record.header_hash.hex(),
                header_block_record.height,
                timestamp,
                bytes(header_block_record),
            ),
        )

        await cursor.close()
        cursor_2 = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?, ?, ?, ?,?)",
            (
                header_block_record.header.header_hash.hex(),
                header_block_record.header.prev_header_hash.hex(),
                header_block_record.header.height,
                header_block_record.header.weight.to_bytes(
                    128 // 8, "big", signed=False).hex(),
                header_block_record.header.total_iters.to_bytes(
                    128 // 8, "big", signed=False).hex(),
                bytes(block_record),
                None if block_record.sub_epoch_summary_included is None else
                bytes(block_record.sub_epoch_summary_included),
                False,
            ),
        )
        await cursor_2.close()

        if len(additional_coin_spends) > 0:
            blob: bytes = bytes(AdditionalCoinSpends(additional_coin_spends))
            # Key by hex text like every other header_hash column here, so that
            # get_additional_coin_spends (which queries with .hex()) finds it.
            cursor_3 = await self.db.execute(
                "INSERT OR REPLACE INTO additional_coin_spends VALUES(?, ?)",
                (header_block_record.header_hash.hex(), blob))
            await cursor_3.close()

    async def get_header_block_at(self,
                                  heights: List[uint32]) -> List[HeaderBlock]:
        """Return the HeaderBlocks stored at any of ``heights`` ([] for no heights)."""
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from header_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        # header_blocks.block holds a serialized HeaderBlockRecord (see
        # add_block_record), so parse it as such and return its header.
        return [HeaderBlockRecord.from_bytes(row[0]).header for row in rows]

    async def get_header_block_record(
            self, header_hash: bytes32) -> Optional[HeaderBlockRecord]:
        """Gets a block record from the database, if present"""
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            return cached
        cursor = await self.db.execute(
            "SELECT block from header_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            hbr: HeaderBlockRecord = HeaderBlockRecord.from_bytes(row[0])
            self.block_cache.put(hbr.header_hash, hbr)
            return hbr
        else:
            return None

    async def get_additional_coin_spends(
            self, header_hash: bytes32) -> Optional[List[CoinSolution]]:
        """Return the additional coin spends stored for ``header_hash``, or None."""
        cursor = await self.db.execute(
            "SELECT spends_list_blob from additional_coin_spends WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            coin_spends: AdditionalCoinSpends = AdditionalCoinSpends.from_bytes(
                row[0])
            return coin_spends.coin_spends_list
        else:
            return None

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        """Return the BlockRecord stored for ``header_hash``, or None."""
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records(
        self, ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """
        cursor = await self.db.execute(
            "SELECT header_hash, block, is_peak from block_records")
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        peak: Optional[bytes32] = None
        for row in rows:
            header_hash_bytes, block_record_bytes, is_peak = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)
            if is_peak:
                assert peak is None  # Sanity check, only one peak
                peak = header_hash
        return ret, peak

    def rollback_cache_block(self, header_hash: bytes32):
        """Drop ``header_hash`` from the in-memory block cache."""
        self.block_cache.remove(header_hash)

    async def set_peak(self, header_hash: bytes32) -> None:
        """Clear the current is_peak flag and set it on ``header_hash``."""
        cursor_1 = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await cursor_1.close()
        cursor_2 = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await cursor_2.close()

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Return the block records with height >= peak height - ``blocks_n``,
        together with the peak's header hash; ({}, None) if no peak is set.
        """

        res = await self.db.execute(
            "SELECT header_hash, height from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, None
        header_hash_bytes, peak_height = row
        peak: bytes32 = bytes32(bytes.fromhex(header_hash_bytes))

        cursor = await self.db.execute(
            "SELECT header_hash, block from block_records WHERE height >= ?",
            (peak_height - blocks_n, ),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)
        return ret, peak

    async def get_header_blocks_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, HeaderBlock]:
        """Return header hash -> HeaderBlock for heights in [start, stop] (inclusive)."""

        cursor = await self.db.execute(
            "SELECT header_hash, block from header_blocks WHERE height >= ? and height <= ?",
            (start, stop),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, HeaderBlock] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            # header_blocks.block holds a serialized HeaderBlockRecord (see
            # add_block_record); return its header.
            ret[header_hash] = HeaderBlockRecord.from_bytes(block_record_bytes).header

        return ret

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """Return header hash -> BlockRecord for heights in [start, stop] (inclusive)."""

        cursor = await self.db.execute(
            "SELECT header_hash, block from block_records WHERE height >= ? and height <= ?",
            (start, stop),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)

        return ret

    async def get_peak_heights_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Walk from the peak back to height 0 through prev_hash links.

        Returns (height -> header hash, height -> SubEpochSummary) for the
        blocks on that path; the second map only has entries for blocks with a
        stored sub_epoch_summary. Returns ({}, {}) if no peak is set. A missing
        link in the chain raises KeyError.
        """

        res = await self.db.execute(
            "SELECT header_hash from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for row in rows:
            hash_to_prev_hash[bytes.fromhex(row[0])] = bytes.fromhex(row[1])
            hash_to_height[bytes.fromhex(row[0])] = row[2]
            if row[3] is not None:
                hash_to_summary[bytes.fromhex(
                    row[0])] = SubEpochSummary.from_bytes(row[3])

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries
Ejemplo n.º 6
0
class BlockStore:
    """SQLite-backed storage for full blocks, block records, the chain peak
    and sub-epoch challenge segments.

    Two on-disk layouts are supported, selected by ``db_wrapper.db_version``:

    * version 2: hashes are raw ``blob`` columns, full blocks are stored
      zstd-compressed, the serialized ``BlockRecord`` is stored next to the
      block in ``full_blocks``, main-chain membership is tracked with the
      ``in_main_chain`` column and the peak hash lives in the single-row
      ``current_peak`` table.
    * any other version (legacy layout): hashes are hex-encoded ``text``
      columns, full blocks are stored uncompressed, block records live in a
      separate ``block_records`` table and the peak row is flagged with
      ``is_peak = 1``.

    Two in-memory LRU caches sit in front of the database: one for full
    blocks keyed by header hash, and one for sub-epoch challenge segments
    keyed by the sub-epoch-summary block hash.
    """

    db: aiosqlite.Connection
    block_cache: LRUCache
    db_wrapper: DBWrapper
    ses_challenge_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        """Build a ``BlockStore`` on ``db_wrapper``'s connection.

        Creates the tables and indexes of the layout selected by
        ``db_wrapper.db_version`` if they do not exist yet, commits, and
        initialises the block cache (1000 entries) and the sub-epoch
        challenge segment cache (50 entries).
        """
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db

        if self.db_wrapper.db_version == 2:

            # TODO: most data in block is duplicated in block_record. The only
            # reason for this is that our parsing of a FullBlock is so slow,
            # it's faster to store duplicate data to parse less when we just
            # need the BlockRecord. Once we fix the parsing (and data structure)
            # of FullBlock, this can use less space
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS full_blocks("
                "header_hash blob PRIMARY KEY,"
                "prev_hash blob,"
                "height bigint,"
                "sub_epoch_summary blob,"
                "is_fully_compactified tinyint,"
                "in_main_chain tinyint,"
                "block blob,"
                "block_record blob)"
            )

            # This is a single-row table containing the hash of the current
            # peak. The "key" field is there to make update statements simple
            await self.db.execute("CREATE TABLE IF NOT EXISTS current_peak(key int PRIMARY KEY, hash blob)")

            await self.db.execute("CREATE INDEX IF NOT EXISTS height on full_blocks(height)")

            # Sub epoch segments for weight proofs
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3("
                "ses_block_hash blob PRIMARY KEY,"
                "challenge_segments blob)"
            )

            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS is_fully_compactified ON"
                " full_blocks(is_fully_compactified, in_main_chain) WHERE in_main_chain=1"
            )
            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS main_chain ON full_blocks(height, in_main_chain) WHERE in_main_chain=1"
            )

        else:

            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
                "  is_block tinyint, is_fully_compactified tinyint, block blob)"
            )

            # Block records
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS block_records(header_hash "
                "text PRIMARY KEY, prev_hash text, height bigint,"
                "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
            )

            # Sub epoch segments for weight proofs
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3(ses_block_hash text PRIMARY KEY,"
                "challenge_segments blob)"
            )

            # Height index so we can look up in order of height for sync purposes
            await self.db.execute("CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)")
            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
            )

            await self.db.execute("CREATE INDEX IF NOT EXISTS height on block_records(height)")

            await self.db.execute("CREATE INDEX IF NOT EXISTS peak on block_records(is_peak) where is_peak = 1")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        self.ses_challenge_cache = LRUCache(50)
        return self

    def maybe_from_hex(self, field: Any) -> bytes:
        """Convert a hash value read from the database into bytes.

        Version-2 databases store hashes as blobs, which are returned as-is;
        the legacy layout stores them as hex text, which is decoded.
        """
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return bytes.fromhex(field)

    def maybe_to_hex(self, field: bytes) -> Any:
        """Convert a hash into the form used as a query parameter.

        Returned unchanged for version-2 databases (blob columns) and
        hex-encoded for the legacy layout (text columns).
        """
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return field.hex()

    def compress(self, block: FullBlock) -> bytes:
        """Serialize ``block`` and zstd-compress it (the version-2 storage format)."""
        return zstd.compress(bytes(block))

    def maybe_decompress(self, block_bytes: bytes) -> FullBlock:
        """Deserialize a stored ``block`` column into a ``FullBlock``.

        Version-2 values are zstd-compressed and are decompressed first;
        legacy values hold the plain serialized block.
        """
        if self.db_wrapper.db_version == 2:
            return FullBlock.from_bytes(zstd.decompress(block_bytes))
        else:
            return FullBlock.from_bytes(block_bytes)

    async def rollback(self, height: int) -> None:
        """Clear ``in_main_chain`` on every main-chain block above ``height``.

        Only the version-2 layout tracks main-chain membership, so this is a
        no-op for the legacy layout. Does not commit.
        """
        if self.db_wrapper.db_version == 2:
            await self.db.execute(
                "UPDATE OR FAIL full_blocks SET in_main_chain=0 WHERE height>? AND in_main_chain=1", (height,)
            )

    async def set_in_chain(self, header_hashes: List[Tuple[bytes32]]) -> None:
        """Set ``in_main_chain=1`` for each given header hash (version 2 only).

        ``header_hashes`` is a list of 1-tuples so it can be passed directly
        to ``executemany`` as parameter rows. No-op for the legacy layout;
        does not commit.
        """
        if self.db_wrapper.db_version == 2:
            await self.db.executemany(
                "UPDATE OR FAIL full_blocks SET in_main_chain=1 WHERE header_hash=?", header_hashes
            )

    async def add_full_block(
        self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord, in_main_chain: bool
    ) -> None:
        """Insert (or replace) ``block`` and its ``block_record``.

        The block is also put in the in-memory block cache. For version 2 a
        single ``full_blocks`` row holds the compressed block, the serialized
        block record, the sub-epoch summary (if any) and ``in_main_chain``.
        The legacy layout ignores ``in_main_chain`` and writes one row to
        ``full_blocks`` plus one to ``block_records`` with ``is_peak`` false.
        Does not commit.
        """
        self.block_cache.put(header_hash, block)

        if self.db_wrapper.db_version == 2:

            ses: Optional[bytes] = (
                None
                if block_record.sub_epoch_summary_included is None
                else bytes(block_record.sub_epoch_summary_included)
            )

            await self.db.execute(
                "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    header_hash,
                    block.prev_header_hash,
                    block.height,
                    ses,
                    int(block.is_fully_compactified()),
                    in_main_chain,  # in_main_chain
                    self.compress(block),
                    bytes(block_record),
                ),
            )

        else:
            await self.db.execute(
                "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
                (
                    header_hash.hex(),
                    block.height,
                    int(block.is_transaction_block()),
                    int(block.is_fully_compactified()),
                    bytes(block),
                ),
            )

            await self.db.execute(
                "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
                (
                    header_hash.hex(),
                    block.prev_header_hash.hex(),
                    block.height,
                    bytes(block_record),
                    None
                    if block_record.sub_epoch_summary_included is None
                    else bytes(block_record.sub_epoch_summary_included),
                    False,  # is_peak
                    block.is_transaction_block(),
                ),
            )

    async def persist_sub_epoch_challenge_segments(
        self, ses_block_hash: bytes32, segments: List[SubEpochChallengeSegment]
    ) -> None:
        """Store ``segments`` for ``ses_block_hash`` and commit, holding ``db_wrapper.lock``."""
        async with self.db_wrapper.lock:
            await self.db.execute(
                "INSERT OR REPLACE INTO sub_epoch_segments_v3 VALUES(?, ?)",
                (self.maybe_to_hex(ses_block_hash), bytes(SubEpochSegments(segments))),
            )
            await self.db.commit()

    async def get_sub_epoch_challenge_segments(
        self,
        ses_block_hash: bytes32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        """Return the challenge segments stored for ``ses_block_hash``, or None.

        Served from the segment cache when possible; segments loaded from the
        database are added to that cache.
        """
        cached = self.ses_challenge_cache.get(ses_block_hash)
        if cached is not None:
            return cached

        async with self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v3 WHERE ses_block_hash=?",
            (self.maybe_to_hex(ses_block_hash),),
        ) as cursor:
            row = await cursor.fetchone()

        if row is not None:
            challenge_segments = SubEpochSegments.from_bytes(row[0]).challenge_segments
            self.ses_challenge_cache.put(ses_block_hash, challenge_segments)
            return challenge_segments
        return None

    def rollback_cache_block(self, header_hash: bytes32) -> None:
        """Evict ``header_hash`` from the block cache if it is present."""
        try:
            self.block_cache.remove(header_hash)
        except KeyError:
            # this is best effort. When rolling back, we may not have added the
            # block to the cache yet
            pass

    async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]:
        """Return the full block for ``header_hash``, or None if it is not stored.

        Served from the block cache when possible; a block loaded from the
        database is added to the cache.
        """
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return cached
        log.debug(f"cache miss for block {header_hash.hex()}")
        async with self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is not None:
            block = self.maybe_decompress(row[0])
            self.block_cache.put(header_hash, block)
            return block
        return None

    async def get_full_block_bytes(self, header_hash: bytes32) -> Optional[bytes]:
        """Return the serialized, uncompressed full block for ``header_hash``, or None.

        A cache hit re-serializes the cached block. Unlike ``get_full_block``,
        a database hit does not populate the cache.
        """
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return bytes(cached)
        log.debug(f"cache miss for block {header_hash.hex()}")
        async with self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is not None:
            if self.db_wrapper.db_version == 2:
                return zstd.decompress(row[0])
            else:
                return row[0]

        return None

    async def get_full_blocks_at(self, heights: List[uint32]) -> List[FullBlock]:
        """Return every stored full block whose height is in ``heights``.

        The query does not filter on main-chain membership, so several blocks
        may be returned for one height. Rows come back in whatever order
        SQLite produces (there is no ORDER BY).
        """
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from full_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        async with self.db.execute(formatted_str, heights_db) as cursor:
            ret: List[FullBlock] = []
            for row in await cursor.fetchall():
                ret.append(self.maybe_decompress(row[0]))
            return ret

    async def get_block_records_by_hash(self, header_hashes: List[bytes32]) -> List[BlockRecord]:
        """
        Returns a list of Block Records, ordered by the same order in which header_hashes are passed in.
        Raises ValueError if any of the blocks is not present.
        """
        if len(header_hashes) == 0:
            return []

        all_blocks: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:
            async with self.db.execute(
                "SELECT header_hash,block_record FROM full_blocks "
                f'WHERE header_hash in ({"?," * (len(header_hashes) - 1)}?)',
                tuple(header_hashes),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    all_blocks[header_hash] = BlockRecord.from_bytes(row[1])
        else:
            formatted_str = f'SELECT block from block_records WHERE header_hash in ({"?," * (len(header_hashes) - 1)}?)'
            async with self.db.execute(formatted_str, tuple([hh.hex() for hh in header_hashes])) as cursor:
                for row in await cursor.fetchall():
                    block_rec: BlockRecord = BlockRecord.from_bytes(row[0])
                    all_blocks[block_rec.header_hash] = block_rec

        ret: List[BlockRecord] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_blocks_by_hash(self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Returns a list of Full Blocks blocks, ordered by the same order in which header_hashes are passed in.
        Every block read from the database is also added to the block cache.
        Raises ValueError if any of the blocks is not present.
        """

        if len(header_hashes) == 0:
            return []

        header_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            header_hashes_db = tuple(header_hashes)
        else:
            header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = (
            f'SELECT header_hash, block from full_blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        )
        all_blocks: Dict[bytes32, FullBlock] = {}
        async with self.db.execute(formatted_str, header_hashes_db) as cursor:
            for row in await cursor.fetchall():
                # Wrap in bytes32 (as the other readers do) so the dict key
                # has the declared type.
                header_hash = bytes32(self.maybe_from_hex(row[0]))
                full_block: FullBlock = self.maybe_decompress(row[1])
                all_blocks[header_hash] = full_block
                self.block_cache.put(header_hash, full_block)
        ret: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
        """Return the ``BlockRecord`` stored for ``header_hash``, or None."""

        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT block_record FROM full_blocks WHERE header_hash=?",
                (header_hash,),
            ) as cursor:
                row = await cursor.fetchone()
            if row is not None:
                return BlockRecord.from_bytes(row[0])

        else:
            async with self.db.execute(
                "SELECT block from block_records WHERE header_hash=?",
                (header_hash.hex(),),
            ) as cursor:
                row = await cursor.fetchone()
            if row is not None:
                return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks in range between start and stop
        (both inclusive), if present.
        """

        ret: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND height <= ?",
                (start, stop),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        else:

            # Bind the bounds as parameters instead of formatting them into the SQL.
            async with self.db.execute(
                "SELECT header_hash, block from block_records WHERE height >= ? and height <= ?",
                (start, stop),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(self.maybe_from_hex(row[0]))
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_peak(self) -> Optional[Tuple[bytes32, uint32]]:
        """Return ``(header_hash, height)`` of the current peak, or None if unset.

        For version 2 this also returns None when the hash in ``current_peak``
        has no matching ``full_blocks`` row.
        """

        if self.db_wrapper.db_version == 2:
            async with self.db.execute("SELECT hash FROM current_peak WHERE key = 0") as cursor:
                peak_row = await cursor.fetchone()
            if peak_row is None:
                return None
            async with self.db.execute("SELECT height FROM full_blocks WHERE header_hash=?", (peak_row[0],)) as cursor:
                peak_height = await cursor.fetchone()
            if peak_height is None:
                return None
            return bytes32(peak_row[0]), uint32(peak_height[0])
        else:
            async with self.db.execute("SELECT header_hash, height from block_records WHERE is_peak = 1") as cursor:
                peak_row = await cursor.fetchone()
            if peak_row is None:
                return None
            return bytes32(bytes.fromhex(peak_row[0])), uint32(peak_row[1])

    async def get_block_records_close_to_peak(
        self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
        peak header hash.
        """

        peak = await self.get_peak()
        if peak is None:
            return {}, None

        ret: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT header_hash, block_record FROM full_blocks WHERE height >= ?",
                (peak[1] - blocks_n,),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        else:
            # Bind the lower height bound instead of formatting it into the SQL.
            async with self.db.execute(
                "SELECT header_hash, block from block_records WHERE height >= ?",
                (peak[1] - blocks_n,),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(self.maybe_from_hex(row[0]))
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret, peak[0]

    async def set_peak(self, header_hash: bytes32) -> None:
        """Record ``header_hash`` as the current peak (not committed here)."""
        # We need to be in a sqlite transaction here.
        # Note: we do not commit this to the database yet, as we need to also change the coin store

        if self.db_wrapper.db_version == 2:
            # Note: we use the key field as 0 just to ensure all inserts replace the existing row
            await self.db.execute("INSERT OR REPLACE INTO current_peak VALUES(?, ?)", (0, header_hash))
        else:
            await self.db.execute("UPDATE block_records SET is_peak=0 WHERE is_peak=1")
            await self.db.execute(
                "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
                (self.maybe_to_hex(header_hash),),
            )

    async def is_fully_compactified(self, header_hash: bytes32) -> Optional[bool]:
        """Return whether the stored block is fully compactified, or None if it is not stored."""
        async with self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        return bool(row[0])

    async def get_random_not_compactified(self, number: int) -> List[int]:
        """Return up to ``number`` randomly chosen heights that still need compactification.

        Version 2 considers only main-chain blocks. The legacy layout returns
        heights at which no stored block is fully compactified.
        """

        if self.db_wrapper.db_version == 2:
            async with self.db.execute(
                "SELECT height FROM full_blocks WHERE in_main_chain=1 AND is_fully_compactified=0 "
                "ORDER BY RANDOM() LIMIT ?",
                (number,),
            ) as cursor:
                rows = await cursor.fetchall()
        else:
            # Since orphan blocks do not get compactified, we need to check whether all blocks with a
            # certain height are not compact. And if we do have compact orphan blocks, then all that
            # happens is that the occasional chain block stays uncompact - not ideal, but harmless.
            async with self.db.execute(
                "SELECT height FROM full_blocks GROUP BY height HAVING sum(is_fully_compactified)=0 "
                "ORDER BY RANDOM() LIMIT ?",
                (number,),
            ) as cursor:
                rows = await cursor.fetchall()

        heights = [int(row[0]) for row in rows]

        return heights

    async def count_compactified_blocks(self) -> int:
        """Return the number of ``full_blocks`` rows with ``is_fully_compactified=1``."""
        async with self.db.execute("select count(*) from full_blocks where is_fully_compactified=1") as cursor:
            row = await cursor.fetchone()

        assert row is not None

        [count] = row
        return int(count)

    async def count_uncompactified_blocks(self) -> int:
        """Return the number of ``full_blocks`` rows with ``is_fully_compactified=0``."""
        async with self.db.execute("select count(*) from full_blocks where is_fully_compactified=0") as cursor:
            row = await cursor.fetchone()

        assert row is not None

        [count] = row
        return int(count)
Ejemplo n.º 7
0
    def test_lru_cache(self):
        # Exercise LRUCache(5): overwriting keys, growth up to capacity, and
        # eviction of the least recently used key. The order of get() calls
        # matters: each successful get() refreshes that key's recency.
        cache = LRUCache(5)

        # A missing key yields None.
        assert cache.get(b"0") is None

        assert len(cache.cache) == 0
        cache.put(b"0", 1)
        assert len(cache.cache) == 1
        assert cache.get(b"0") == 1
        # Repeated puts on the same key overwrite the value without adding entries.
        cache.put(b"0", 2)
        cache.put(b"0", 3)
        cache.put(b"0", 4)
        cache.put(b"0", 6)
        assert cache.get(b"0") == 6
        assert len(cache.cache) == 1

        # Fill the cache up to its capacity of 5 distinct keys.
        cache.put(b"1", 1)
        assert len(cache.cache) == 2
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        cache.put(b"2", 2)
        assert len(cache.cache) == 3
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        cache.put(b"3", 3)
        assert len(cache.cache) == 4
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        assert cache.get(b"3") == 3
        cache.put(b"4", 4)
        assert len(cache.cache) == 5
        # b"3" is deliberately not read here, so it becomes the least recently used key.
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        assert cache.get(b"4") == 4
        # Inserting a sixth key keeps the size at 5 by evicting b"3".
        cache.put(b"5", 5)
        assert cache.get(b"5") == 5
        assert len(cache.cache) == 5
        print(cache.cache)  # debug output of the cache contents
        assert cache.get(b"3") is None  # 3 is least recently used
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        # Recency order is now (oldest first) 0, 4, 5, 1, 2, so b"0" is evicted next.
        cache.put(b"7", 7)
        assert len(cache.cache) == 5
        assert cache.get(b"0") is None
        assert cache.get(b"1") == 1
Ejemplo n.º 8
0
class BlockStore:
    db: aiosqlite.Connection
    block_cache: LRUCache
    db_wrapper: DBWrapper
    ses_challenge_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        """Build a ``BlockStore`` on ``db_wrapper``'s connection.

        Switches the connection to WAL journaling with ``synchronous=2``,
        creates the (hex-text keyed) tables and indexes if they are missing,
        drops the obsolete ``sub_epoch_segments_v2`` table, commits, and
        initialises the block cache (1000 entries) and the sub-epoch
        challenge segment cache (50 entries).
        """
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute("pragma journal_mode=wal")
        await self.db.execute("pragma synchronous=2")
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
            "  is_block tinyint, is_fully_compactified tinyint, block blob)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint,"
            "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
        )

        # todo remove in v1.2
        await self.db.execute("DROP TABLE IF EXISTS sub_epoch_segments_v2")

        # Sub epoch segments for weight proofs
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3(ses_block_hash text PRIMARY KEY, challenge_segments blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)"
        )
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on full_blocks(is_block)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        # NOTE(review): header_hash is already the PRIMARY KEY, which SQLite
        # indexes implicitly, so this index looks redundant.
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        # SQLite index names share one namespace per database. Reusing the
        # name "is_block" (already taken by the full_blocks index above) made
        # this statement a silent no-op under IF NOT EXISTS, so it gets a
        # distinct name.
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS block_records_is_block on block_records(is_block)")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        self.ses_challenge_cache = LRUCache(50)
        return self

    async def add_full_block(self, header_hash: bytes32, block: FullBlock,
                             block_record: BlockRecord) -> None:
        """Insert (or replace) ``block`` and its ``block_record``.

        The block also goes into the in-memory block cache. One row is
        written to ``full_blocks`` and one to ``block_records`` (with
        ``is_peak`` false). Nothing is committed here.
        """
        self.block_cache.put(header_hash, block)
        hash_hex = header_hash.hex()

        full_block_row = (
            hash_hex,
            block.height,
            int(block.is_transaction_block()),
            int(block.is_fully_compactified()),
            bytes(block),
        )
        block_cursor = await self.db.execute(
            "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
            full_block_row,
        )
        await block_cursor.close()

        ses = block_record.sub_epoch_summary_included
        record_row = (
            hash_hex,
            block.prev_header_hash.hex(),
            block.height,
            bytes(block_record),
            bytes(ses) if ses is not None else None,
            False,  # is_peak
            block.is_transaction_block(),
        )
        record_cursor = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
            record_row,
        )
        await record_cursor.close()

    async def persist_sub_epoch_challenge_segments(
            self, ses_block_hash: bytes32,
            segments: List[SubEpochChallengeSegment]) -> None:
        """Store ``segments`` under ``ses_block_hash`` and commit.

        The write and the commit both happen while ``db_wrapper.lock`` is held.
        """
        async with self.db_wrapper.lock:
            params = (ses_block_hash.hex(), bytes(SubEpochSegments(segments)))
            cursor = await self.db.execute(
                "INSERT OR REPLACE INTO sub_epoch_segments_v3 VALUES(?, ?)",
                params,
            )
            await cursor.close()
            await self.db.commit()

    async def get_sub_epoch_challenge_segments(
        self,
        ses_block_hash: bytes32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        """Return the challenge segments stored for ``ses_block_hash``, or None.

        The segment cache is consulted first; segments read from the database
        are put into it.
        """
        hit = self.ses_challenge_cache.get(ses_block_hash)
        if hit is not None:
            return hit

        cursor = await self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v3 WHERE ses_block_hash=?",
            (ses_block_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()

        if row is None:
            return None
        segments = SubEpochSegments.from_bytes(row[0]).challenge_segments
        self.ses_challenge_cache.put(ses_block_hash, segments)
        return segments

    def rollback_cache_block(self, header_hash: bytes32):
        """Drop ``header_hash`` from the in-memory block cache, if it is there."""
        cache = self.block_cache
        try:
            cache.remove(header_hash)
        except KeyError:
            # Best effort only: during a rollback the block may never have
            # been inserted into the cache, which is not an error.
            return

    async def get_full_block(self,
                             header_hash: bytes32) -> Optional[FullBlock]:
        """Return the full block for ``header_hash``, or None if it is not stored.

        The block cache is checked first; a block read from the database is
        added to the cache before being returned.
        """
        hit = self.block_cache.get(header_hash)
        if hit is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return hit

        log.debug(f"cache miss for block {header_hash.hex()}")
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()

        if row is None:
            return None
        full_block = FullBlock.from_bytes(row[0])
        self.block_cache.put(header_hash, full_block)
        return full_block

    async def get_full_block_bytes(self,
                                   header_hash: bytes32) -> Optional[bytes]:
        """Return the serialized full block for ``header_hash``, or None.

        A cache hit re-serializes the cached block; a database hit returns the
        stored bytes without touching the cache.
        """
        hit = self.block_cache.get(header_hash)
        if hit is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return bytes(hit)

        log.debug(f"cache miss for block {header_hash.hex()}")
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        return None if row is None else row[0]

    async def get_full_blocks_at(self,
                                 heights: List[uint32]) -> List[FullBlock]:
        """Return every stored full block whose height is in ``heights``.

        More than one block may come back per height, and the row order is
        whatever SQLite produces (the query has no ORDER BY).
        """
        if len(heights) == 0:
            return []

        params = tuple(heights)
        placeholders = ",".join("?" * len(params))
        cursor = await self.db.execute(
            f"SELECT block from full_blocks WHERE height in ({placeholders})", params
        )
        rows = await cursor.fetchall()
        await cursor.close()

        blocks: List[FullBlock] = []
        for (block_bytes,) in rows:
            blocks.append(FullBlock.from_bytes(block_bytes))
        return blocks

    async def get_block_records_by_hash(self, header_hashes: List[bytes32]):
        """
        Look up block records for ``header_hashes``, preserving the input order.

        Raises ValueError for the first header hash (in input order) that has
        no stored record.
        """
        if len(header_hashes) == 0:
            return []

        params = tuple(hh.hex() for hh in header_hashes)
        placeholders = ",".join("?" * len(params))
        cursor = await self.db.execute(
            f"SELECT block from block_records WHERE header_hash in ({placeholders})", params
        )
        rows = await cursor.fetchall()
        await cursor.close()

        by_hash: Dict[bytes32, BlockRecord] = {}
        for (record_bytes,) in rows:
            record: BlockRecord = BlockRecord.from_bytes(record_bytes)
            by_hash[record.header_hash] = record

        ordered: List[BlockRecord] = []
        for hh in header_hashes:
            if hh not in by_hash:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ordered.append(by_hash[hh])
        return ordered

    async def get_blocks_by_hash(
            self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Look up full blocks for ``header_hashes``, preserving the input order.

        Every block read from the database is also put in the block cache.
        Raises ValueError for the first header hash (in input order) that is
        not stored.
        """

        if len(header_hashes) == 0:
            return []

        params = tuple(hh.hex() for hh in header_hashes)
        placeholders = ",".join("?" * len(params))
        cursor = await self.db.execute(
            f"SELECT header_hash, block from full_blocks WHERE header_hash in ({placeholders})",
            params,
        )
        rows = await cursor.fetchall()
        await cursor.close()

        found: Dict[bytes32, FullBlock] = {}
        for hash_hex, block_bytes in rows:
            key = bytes.fromhex(hash_hex)
            decoded: FullBlock = FullBlock.from_bytes(block_bytes)
            found[key] = decoded
            self.block_cache.put(key, decoded)

        ordered: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in found:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ordered.append(found[hh])
        return ordered

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        """Return the ``BlockRecord`` stored for ``header_hash``, or None."""
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return None if row is None else BlockRecord.from_bytes(row[0])

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary (header hash -> BlockRecord) with all blocks whose
        height is between start and stop (both inclusive), if present.
        """

        # Bind the bounds as parameters rather than formatting them into SQL.
        cursor = await self.db.execute(
            "SELECT header_hash, block from block_records WHERE height >= ? and height <= ?",
            (start, stop),
        )
        try:
            rows = await cursor.fetchall()
        finally:
            # Close the cursor even if fetching fails.
            await cursor.close()

        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary, keyed by header hash, of all block records with
        height >= peak height - blocks_n, together with the peak header hash.
        Returns ({}, None) when no record is flagged as the peak.
        """
        # Name the columns explicitly instead of relying on SELECT * order.
        res = await self.db.execute(
            "SELECT header_hash, height from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, None
        peak_hash_hex, peak_height = peak_row[0], peak_row[1]

        # Bind the height threshold as a parameter rather than formatting it
        # into the SQL text.
        cursor = await self.db.execute(
            "SELECT header_hash, block from block_records WHERE height >= ?",
            (peak_height - blocks_n, ),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {
            bytes.fromhex(row[0]): BlockRecord.from_bytes(row[1]) for row in rows
        }
        return ret, bytes.fromhex(peak_hash_hex)

    async def get_peak_height_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Walks the chain backwards from the current peak down to height 0.

        Returns two dictionaries:
          - height -> header hash for every block on the peak's chain;
          - height -> SubEpochSummary for those chain blocks whose record
            stores a sub-epoch summary.
        Returns ({}, {}) when no record is flagged as the peak.

        Raises KeyError if a block on the peak's chain is missing from
        block_records (the chain back to height 0 is incomplete).
        """
        res = await self.db.execute(
            "SELECT header_hash from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(peak_row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for header_hash_hex, prev_hash_hex, height, summary_bytes in rows:
            # Decode the hex key once per row and reuse it for every map.
            header_hash = bytes.fromhex(header_hash_hex)
            hash_to_prev_hash[header_hash] = bytes.fromhex(prev_hash_hex)
            hash_to_height[header_hash] = height
            if summary_bytes is not None:
                hash_to_summary[header_hash] = SubEpochSummary.from_bytes(
                    summary_bytes)

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        # Follow prev_hash links from the peak; orphaned records that are not
        # ancestors of the peak are never visited.
        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries

    async def set_peak(self, header_hash: bytes32) -> None:
        """Flag ``header_hash`` as the peak in ``block_records``.

        First clears ``is_peak`` on any row that currently has it set, then
        sets it on the row for ``header_hash``. Nothing is committed here: this
        must run inside a sqlite transaction, because the coin store also has
        to be updated before the caller commits.
        """
        clear_cursor = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await clear_cursor.close()

        mark_cursor = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await mark_cursor.close()

    async def is_fully_compactified(self,
                                    header_hash: bytes32) -> Optional[bool]:
        """Report whether the stored full block for ``header_hash`` is compact.

        Returns ``None`` if ``full_blocks`` has no row for ``header_hash``;
        otherwise returns its ``is_fully_compactified`` column coerced to bool.
        """
        cur = await self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        found = await cur.fetchone()
        await cur.close()
        return None if found is None else bool(found[0])

    async def get_random_not_compactified(self, number: int) -> List[int]:
        # Since orphan blocks do not get compactified, we need to check whether all blocks with a
        # certain height are not compact. And if we do have compact orphan blocks, then all that
        # happens is that the occasional chain block stays uncompact - not ideal, but harmless.
        cursor = await self.db.execute(
            f"SELECT height FROM full_blocks GROUP BY height HAVING sum(is_fully_compactified)=0 "
            f"ORDER BY RANDOM() LIMIT {number}")
        rows = await cursor.fetchall()
        await cursor.close()

        heights = []
        for row in rows:
            heights.append(int(row[0]))

        return heights