Example 1
    async def sync_job(self) -> None:
        while True:
            self.log.info("Loop start in sync job")
            if self._shut_down is True:
                break
            asyncio.create_task(self.check_new_peak())
            await self.sync_event.wait()
            self.last_new_peak_messages = LRUCache(5)
            self.sync_event.clear()

            if self._shut_down is True:
                break
            try:
                assert self.wallet_state_manager is not None
                self.wallet_state_manager.set_sync_mode(True)
                await self._sync()
            except Exception as e:
                tb = traceback.format_exc()
                self.log.error(f"Loop exception in sync {e}. {tb}")
            finally:
                if self.wallet_state_manager is not None:
                    self.wallet_state_manager.set_sync_mode(False)
                for peer, peak in self.last_new_peak_messages.cache.items():
                    asyncio.create_task(self.new_peak_wallet(peak, peer))
            self.log.info("Loop end in sync job")
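Everything in this collection leans on the same small LRUCache interface: a constructor taking a capacity, get and put methods, and a cache attribute exposing the underlying mapping (iterated directly in the finally block above). A minimal sketch of such a cache, given here as an assumption about the interface rather than the project's actual implementation:

    from collections import OrderedDict
    from typing import Any, Optional

    class LRUCache:
        def __init__(self, capacity: int):
            self.capacity = capacity
            self.cache: OrderedDict = OrderedDict()

        def get(self, key: Any) -> Optional[Any]:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)  # mark as most recently used
            return self.cache[key]

        def put(self, key: Any, value: Any) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.capacity:
                self.cache.popitem(last=False)  # evict the least recently used entry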
Example 2
    async def create(cls, db_wrapper: DBWrapper):
        self = cls()

        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS header_blocks(header_hash text PRIMARY KEY, height int,"
            " timestamp int, block blob)"
        )

        await self.db.execute("CREATE INDEX IF NOT EXISTS header_hash on header_blocks(header_hash)")

        await self.db.execute("CREATE INDEX IF NOT EXISTS timestamp on header_blocks(timestamp)")

        await self.db.execute("CREATE INDEX IF NOT EXISTS height on header_blocks(height)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint, weight bigint, total_iters text,"
            "block blob, sub_epoch_summary blob, is_peak tinyint)"
        )

        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS additional_coin_spends(header_hash text PRIMARY KEY, spends_list_blob blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute("CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute("CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute("CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.commit()
        self.block_cache = LRUCache(1000)
        return self
Example 3
def get_pairings(cache: LRUCache, pks: List[G1Element], msgs: List[bytes],
                 force_cache: bool) -> List[GTElement]:
    pairings: List[Optional[GTElement]] = []
    missing_count: int = 0
    for pk, msg in zip(pks, msgs):
        aug_msg: bytes = bytes(pk) + msg
        h: bytes = bytes(std_hash(aug_msg))
        pairing: Optional[GTElement] = cache.get(h)
        if not force_cache and pairing is None:
            missing_count += 1
            # Heuristic to avoid more expensive sig validation with pairing
            # cache when it's empty and cached pairings won't be useful later
            # (e.g. while syncing)
            if missing_count > len(pks) // 2:
                return []
        pairings.append(pairing)

    for i, pairing in enumerate(pairings):
        if pairing is None:
            aug_msg = bytes(pks[i]) + msgs[i]
            aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
            pairing = pks[i].pair(aug_hash)

            h = bytes(std_hash(aug_msg))
            cache.put(h, pairing)
            pairings[i] = pairing

    return pairings
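get_pairings only collects cached pairings and backfills misses; a caller still has to multiply them together and compare the product against the pairing of the aggregate signature, falling back to plain verification when the empty-cache heuristic returns []. A sketch of that companion, assuming the LRUCache interface above; the LOCAL_CACHE default and the exact signature are assumptions:

    import functools
    from typing import List

    from blspy import AugSchemeMPL, G1Element, G2Element, GTElement

    LOCAL_CACHE: LRUCache = LRUCache(10000)

    def aggregate_verify(pks: List[G1Element], msgs: List[bytes], sig: G2Element,
                         force_cache: bool = False, cache: LRUCache = LOCAL_CACHE) -> bool:
        pairings: List[GTElement] = get_pairings(cache, pks, msgs, force_cache)
        if len(pairings) == 0:
            # Cache judged not useful (e.g. while syncing): verify directly
            return AugSchemeMPL.aggregate_verify(pks, msgs, sig)
        pairings_prod: GTElement = functools.reduce(GTElement.__mul__, pairings)
        # Aug scheme identity: product of e(pk_i, H(pk_i || m_i)) == e(g1, sig)
        return pairings_prod == sig.pair(G1Element.generator())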
Example 4
    async def create(cls,
                     db_wrapper: DBWrapper,
                     cache_size: uint32 = uint32(60000)):
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db

        if self.db_wrapper.db_version == 2:

            # the coin_name is unique in this table because the CoinStore always
            # only represents a single peak
            await self.coin_record_db.execute(
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name blob PRIMARY KEY,"
                " confirmed_index bigint,"
                " spent_index bigint,"  # if this is zero, it means the coin has not been spent
                " coinbase int,"
                " puzzle_hash blob,"
                " coin_parent blob,"
                " amount blob,"  # we use a blob of 8 bytes to store uint64
                " timestamp bigint)")

        else:

            # the coin_name is unique in this table because the CoinStore always
            # only represents a single peak
            await self.coin_record_db.execute(
                ("CREATE TABLE IF NOT EXISTS coin_record("
                 "coin_name text PRIMARY KEY,"
                 " confirmed_index bigint,"
                 " spent_index bigint,"
                 " spent int,"
                 " coinbase int,"
                 " puzzle_hash text,"
                 " coin_parent text,"
                 " amount blob,"
                 " timestamp bigint)"))

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_parent_index on coin_record(coin_parent)"
        )

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self
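Note the amount column: it is stored as an 8-byte blob rather than a SQLite integer, presumably because SQLite integers are signed 64-bit and cannot hold the full uint64 range. A sketch of the conversion this implies, assuming chia's big-endian encoding:

    def amount_to_blob(amount: int) -> bytes:
        # 8 bytes, big-endian: blob sort order matches numeric order
        return amount.to_bytes(8, "big")

    def blob_to_amount(blob: bytes) -> int:
        return int.from_bytes(blob, "big")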
Example 5
    async def create(cls, db_wrapper: DBWrapper):
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute("pragma journal_mode=wal")
        await self.db.execute("pragma synchronous=2")
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
            "  is_block tinyint, is_fully_compactified tinyint, block blob)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint,"
            "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
        )

        # todo remove in v1.2
        await self.db.execute("DROP TABLE IF EXISTS sub_epoch_segments_v2")

        # Sub epoch segments for weight proofs
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3(ses_block_hash text PRIMARY KEY, challenge_segments blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)"
        )
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on full_blocks(is_block)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on block_records(is_block)")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        self.ses_challenge_cache = LRUCache(50)
        return self
Example 6
    async def create(cls):
        self = cls()
        self.db_path = Path("pooldb.sqlite")
        self.connection = await aiosqlite.connect(self.db_path)
        self.lock = asyncio.Lock()
        await self.connection.execute("pragma journal_mode=wal")
        await self.connection.execute("pragma synchronous=2")
        await self.connection.execute(("CREATE TABLE IF NOT EXISTS farmer("
                                       "singleton_genesis text PRIMARY KEY,"
                                       " owner_public_key text,"
                                       " pool_puzzle_hash text,"
                                       " relative_lock_height bigint,"
                                       " p2_singleton_puzzle_hash text,"
                                       " blockchain_height bigint,"
                                       " singleton_coin_id text,"
                                       " points bigint,"
                                       " difficulty bigint,"
                                       " rewards_target text,"
                                       " is_pool_member tinyint)"))

        # Useful for reorg lookups
        await self.connection.execute(
            "CREATE INDEX IF NOT EXISTS scan_ph on farmer(p2_singleton_puzzle_hash)"
        )

        await self.connection.commit()
        self.coin_record_cache = LRUCache(1000)

        return self
Example 7
    def test_cached_bls(self):
        n_keys = 10
        seed = b"a" * 31
        sks = [AugSchemeMPL.key_gen(seed + bytes([i])) for i in range(n_keys)]
        pks = [bytes(sk.get_g1()) for sk in sks]

        msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys)]
        sigs = [AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)]
        agg_sig = AugSchemeMPL.aggregate(sigs)

        pks_half = pks[: n_keys // 2]
        msgs_half = msgs[: n_keys // 2]
        sigs_half = sigs[: n_keys // 2]
        agg_sig_half = AugSchemeMPL.aggregate(sigs_half)

        assert AugSchemeMPL.aggregate_verify([G1Element.from_bytes(pk) for pk in pks], msgs, agg_sig)

        # Verify with empty cache and populate it
        assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
        # Verify with partial cache hit
        assert cached_bls.aggregate_verify(pks, msgs, agg_sig, True)
        # Verify with full cache hit
        assert cached_bls.aggregate_verify(pks, msgs, agg_sig)

        # Use a small cache which can not accommodate all pairings
        local_cache = LRUCache(n_keys // 2)
        # Verify signatures and cache pairings one at a time
        for pk, msg, sig in zip(pks_half, msgs_half, sigs_half):
            assert cached_bls.aggregate_verify([pk], [msg], sig, True, local_cache)
        # Verify the same messages with aggregated signature (full cache hit)
        assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, False, local_cache)
        # Verify more messages (partial cache hit)
        assert cached_bls.aggregate_verify(pks, msgs, agg_sig, False, local_cache)
Example 8
    def __init__(self, private_key: PrivateKey, config: Dict,
                 constants: ConsensusConstants):
        self.log = logging.getLogger(__name__)
        self.private_key = private_key
        self.public_key: G1Element = private_key.get_g1()
        self.config = config
        self.constants = constants
        self.node_rpc_client = None

        self.store: Optional[PoolStore] = None

        self.pool_fee = 0.01

        # This number should be held constant and be consistent for every pool in the network
        self.iters_limit = 734000000

        # This number should not be changed, since users will put this into their singletons
        self.relative_lock_height = uint32(100)

        # TODO: potentially tweak these numbers for security and performance
        self.pool_url = "https://myreferencepool.com"
        self.min_difficulty = uint64(
            100)  # 100 difficulty is about 1 proof a day per plot
        self.default_difficulty: uint64 = uint64(100)
        self.max_difficulty = uint64(1000)

        # TODO: store this information in a persistent DB
        self.account_points: Dict[bytes, int] = {
        }  # Points are added by submitting partials
        self.account_rewards_targets: Dict[bytes, bytes] = {}

        self.pending_point_partials: Optional[asyncio.Queue] = None
        self.recent_points_added: LRUCache = LRUCache(20000)

        # This is where the block rewards will get paid out to. The pool needs to support this address forever,
        # since the farmers will encode it into their singleton on the blockchain.
        self.default_pool_puzzle_hash: bytes32 = decode_puzzle_hash(
            "xch12ma5m7sezasgh95wkyr8470ngryec27jxcvxcmsmc4ghy7c4njssnn623q")

        # We need to check for slow farmers. If farmers cannot submit proofs in time, they won't be able to win
        # any rewards either. This number can be tweaked to be more or less strict. More strict ensures everyone
        # gets high rewards, but it might cause some of the slower farmers to not be able to participate in the pool.
        self.partial_time_limit: int = 25

        # There is always a risk of a reorg, in which case we cannot reward farmers that submitted partials in that
        # reorg. That is why we have a time delay before changing any account points.
        self.partial_confirmation_delay: int = 300

        self.full_node_client: Optional[FullNodeRpcClient] = None
        self.confirm_partials_loop_task: Optional[asyncio.Task] = None
        self.difficulty_change_time: Dict[bytes32, uint64] = {}

        self.scan_p2_singleton_puzzle_hashes: Set[bytes32] = set()
        self.blockchain_state = {"peak": None}
Example 9
    async def create(cls,
                     db_wrapper: DBWrapper,
                     cache_size: uint32 = uint32(60000)):
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db
        await self.coin_record_db.execute("pragma journal_mode=wal")
        await self.coin_record_db.execute("pragma synchronous=2")
        await self.coin_record_db.execute(
            ("CREATE TABLE IF NOT EXISTS coin_record("
             "coin_name text PRIMARY KEY,"
             " confirmed_index bigint,"
             " spent_index bigint,"
             " spent int,"
             " coinbase int,"
             " puzzle_hash text,"
             " coin_parent text,"
             " amount blob,"
             " timestamp bigint)"))

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_spent on coin_record(spent)")

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)"
        )

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self
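A lookup against this store would check coin_record_cache before touching the database and repopulate it on a miss. A hypothetical get_coin_record illustrating the pattern (column positions follow the schema above; Coin, CoinRecord, bytes32 and uint64 are the usual chia types; this is a sketch, not the store's actual method):

    async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]:
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached
        cursor = await self.coin_record_db.execute(
            "SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        # Columns: coin_name, confirmed_index, spent_index, spent, coinbase,
        #          puzzle_hash, coin_parent, amount, timestamp
        coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])),
                    uint64.from_bytes(row[7]))
        record = CoinRecord(coin, row[1], row[2], row[3], row[4], row[8])
        self.coin_record_cache.put(coin_name, record)
        return record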
Example 10
    def __init__(
        self,
        config: Dict,
        keychain: Keychain,
        root_path: Path,
        consensus_constants: ConsensusConstants,
        name: str = None,
    ):
        self.config = config
        self.constants = consensus_constants
        self.root_path = root_path
        if name:
            self.log = logging.getLogger(name)
        else:
            self.log = logging.getLogger(__name__)
        # Normal operation data
        self.cached_blocks: Dict = {}
        self.future_block_hashes: Dict = {}
        self.keychain = keychain

        # Sync data
        self._shut_down = False
        self.proof_hashes: List = []
        self.header_hashes: List = []
        self.header_hashes_error = False
        self.short_sync_threshold = 15  # Change the test when changing this
        self.potential_blocks_received: Dict = {}
        self.potential_header_hashes: Dict = {}
        self.state_changed_callback = None
        self.wallet_state_manager = None
        self.backup_initialized = False  # Delay first launch sync after user imports backup info or decides to skip
        self.server = None
        self.wsm_close_task = None
        self.sync_task: Optional[asyncio.Task] = None
        self.new_peak_lock: Optional[asyncio.Lock] = None
        self.logged_in_fingerprint: Optional[int] = None
        self.peer_task = None
        self.logged_in = False
        self.last_new_peak_messages = LRUCache(5)
Example 11
 def __init__(self, constants: ConsensusConstants):
     self.candidate_blocks = {}
     self.candidate_backup_blocks = {}
     self.seen_unfinished_blocks = set()
     self.unfinished_blocks = {}
     self.finished_sub_slots = []
     self.future_eos_cache = {}
     self.future_sp_cache = {}
     self.future_ip_cache = {}
     self.recent_signage_points = LRUCache(500)
     self.recent_eos = LRUCache(50)
     self.requesting_unfinished_blocks = set()
     self.previous_generator = None
     self.future_cache_key_times = {}
     self.constants = constants
     self.clear_slots()
     self.initialize_genesis_sub_slot()
     self.pending_tx_request = {}
     self.peers_with_tx = {}
     self.tx_fetch_tasks = {}
     self.serialized_wp_message = None
     self.serialized_wp_message_tip = None
Example 12
def validate_clvm_and_signature(
        spend_bundle_bytes: bytes, max_cost: int, cost_per_byte: int,
        additional_data: bytes
) -> Tuple[Optional[Err], bytes, Dict[bytes, bytes]]:
    """
    Validates CLVM and aggregate signature for a spendbundle. This is meant to be called under a ProcessPoolExecutor
    in order to validate the heavy parts of a transaction in a separate process. Returns an optional error,
    the NPCResult, and a cache of the new pairings validated (if there was no error)
    """
    try:
        bundle: SpendBundle = SpendBundle.from_bytes(spend_bundle_bytes)
        program = simple_solution_generator(bundle)
        # npc contains names of the coins removed, puzzle_hashes and their spend conditions
        result: NPCResult = get_name_puzzle_conditions(
            program, max_cost, cost_per_byte=cost_per_byte, mempool_mode=True)

        if result.error is not None:
            return Err(result.error), b"", {}

        pks: List[G1Element] = []
        msgs: List[bytes32] = []
        # TODO: address hint error and remove ignore
        #       error: Incompatible types in assignment (expression has type "List[bytes]", variable has type
        #       "List[bytes32]")  [assignment]
        pks, msgs = pkm_pairs(result.npc_list,
                              additional_data)  # type: ignore[assignment]

        # Verify aggregated signature
        cache: LRUCache = LRUCache(10000)
        if not cached_bls.aggregate_verify(
                pks, msgs, bundle.aggregated_signature, True, cache):
            return Err.BAD_AGGREGATE_SIGNATURE, b"", {}
        new_cache_entries: Dict[bytes, bytes] = {}
        for k, v in cache.cache.items():
            new_cache_entries[k] = bytes(v)
    except ValidationError as e:
        return e.code, b"", {}
    except Exception:
        return Err.UNKNOWN, b"", {}

    return None, bytes(result), new_cache_entries
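The LRUCache built inside this function dies with the worker process; what survives is the serialized new_cache_entries mapping, which the caller on the main process can fold into its own long-lived cache. A sketch, where LOCAL_BLS_CACHE is an assumed main-process cache:

    from typing import Dict

    from blspy import GTElement

    LOCAL_BLS_CACHE: LRUCache = LRUCache(10000)

    def merge_pairing_cache(new_cache_entries: Dict[bytes, bytes]) -> None:
        # Re-hydrate each pairing the worker validated so later
        # aggregate signature checks in this process can reuse it
        for h, pairing_bytes in new_cache_entries.items():
            LOCAL_BLS_CACHE.put(h, GTElement.from_bytes(pairing_bytes))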
Example 13
class BlockStore:
    db: aiosqlite.Connection
    block_cache: LRUCache

    @classmethod
    async def create(cls, connection: aiosqlite.Connection):
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db = connection
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
            "  is_block tinyint, is_fully_compactified tinyint, block blob)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint,"
            "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
        )

        # todo remove in v1.2
        await self.db.execute("DROP TABLE IF EXISTS sub_epoch_segments")

        # Sub epoch segments for weight proofs
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v2(ses_height bigint PRIMARY KEY, challenge_segments blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)"
        )
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on full_blocks(is_block)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on block_records(is_block)")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        return self

    async def begin_transaction(self) -> None:
        # Also locks the coin store, since both stores must be updated at once
        cursor = await self.db.execute("BEGIN TRANSACTION")
        await cursor.close()

    async def commit_transaction(self) -> None:
        await self.db.commit()

    async def rollback_transaction(self) -> None:
        # Also rolls back the coin store, since both stores must be updated at once
        cursor = await self.db.execute("ROLLBACK")
        await cursor.close()

    async def add_full_block(self, block: FullBlock,
                             block_record: BlockRecord) -> None:
        self.block_cache.put(block.header_hash, block)
        cursor_1 = await self.db.execute(
            "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
            (
                block.header_hash.hex(),
                block.height,
                int(block.is_transaction_block()),
                int(block.is_fully_compactified()),
                bytes(block),
            ),
        )

        await cursor_1.close()

        cursor_2 = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
            (
                block.header_hash.hex(),
                block.prev_header_hash.hex(),
                block.height,
                bytes(block_record),
                None if block_record.sub_epoch_summary_included is None else
                bytes(block_record.sub_epoch_summary_included),
                False,
                block.is_transaction_block(),
            ),
        )
        await cursor_2.close()
        await self.db.commit()

    async def persist_sub_epoch_challenge_segments(
            self, sub_epoch_summary_height: uint32,
            segments: List[SubEpochChallengeSegment]) -> None:
        cursor_1 = await self.db.execute(
            "INSERT OR REPLACE INTO sub_epoch_segments_v2 VALUES(?, ?)",
            (sub_epoch_summary_height, bytes(SubEpochSegments(segments))),
        )
        await cursor_1.close()

    async def get_sub_epoch_challenge_segments(
        self,
        sub_epoch_summary_height: uint32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        cursor = await self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v2 WHERE ses_height=?",
            (sub_epoch_summary_height, ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return SubEpochSegments.from_bytes(row[0]).challenge_segments
        return None

    async def delete_sub_epoch_challenge_segments(self,
                                                  fork_height: uint32) -> None:
        cursor = await self.db.execute(
            "delete from sub_epoch_segments_v2 WHERE ses_height>?",
            (fork_height, ))
        await cursor.close()

    async def get_full_block(self,
                             header_hash: bytes32) -> Optional[FullBlock]:
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            return cached
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return FullBlock.from_bytes(row[0])
        return None

    async def get_full_blocks_at(self,
                                 heights: List[uint32]) -> List[FullBlock]:
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from full_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [FullBlock.from_bytes(row[0]) for row in rows]

    async def get_block_records_at(self,
                                   heights: List[uint32]) -> List[BlockRecord]:
        if len(heights) == 0:
            return []
        heights_db = tuple(heights)
        formatted_str = (
            f'SELECT block from block_records WHERE height in ({"?," * (len(heights_db) - 1)}?) ORDER BY height ASC;'
        )
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [BlockRecord.from_bytes(row[0]) for row in rows]

    async def get_blocks_by_hash(
            self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Returns a list of full blocks, in the same order in which the header_hashes are passed in.
        Throws an exception if the blocks are not present
        """

        if len(header_hashes) == 0:
            return []

        header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = f'SELECT block from full_blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, header_hashes_db)
        rows = await cursor.fetchall()
        await cursor.close()
        all_blocks: Dict[bytes32, FullBlock] = {}
        for row in rows:
            full_block: FullBlock = FullBlock.from_bytes(row[0])
            all_blocks[full_block.header_hash] = full_block
        ret: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_header_blocks_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, HeaderBlock]:

        formatted_str = f"SELECT header_hash,block from full_blocks WHERE height >= {start} and height <= {stop}"

        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, HeaderBlock] = {}
        for row in rows:
            # Ugly hack, until full_block.get_block_header is rewritten as part of generator runner change
            await asyncio.sleep(0.001)
            header_hash = bytes.fromhex(row[0])
            full_block: FullBlock = FullBlock.from_bytes(row[1])
            ret[header_hash] = full_block.get_block_header()

        return ret

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records(
        self, ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """
        cursor = await self.db.execute("SELECT * from block_records")
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        peak: Optional[bytes32] = None
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[3])
            if row[5]:
                assert peak is None  # Sanity check, only one peak
                peak = header_hash
        return ret, peak

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks in range between start and stop
        if present.
        """

        formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"

        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
        peak header hash.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, None

        formatted_str = f"SELECT header_hash, block  from block_records WHERE height >= {peak_row[2] - blocks_n}"
        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])
        return ret, bytes.fromhex(peak_row[0])

    async def get_peak_height_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Returns two dictionaries: height to header hash for the main chain up to the peak,
        and height to sub epoch summary for the heights where one was included.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for row in rows:
            hash_to_prev_hash[bytes.fromhex(row[0])] = bytes.fromhex(row[1])
            hash_to_height[bytes.fromhex(row[0])] = row[2]
            if row[3] is not None:
                hash_to_summary[bytes.fromhex(
                    row[0])] = SubEpochSummary.from_bytes(row[3])

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries

    async def set_peak(self, header_hash: bytes32) -> None:
        # We need to be in a sqlite transaction here.
        # Note: we do not commit this to the database yet, as we need to also change the coin store
        cursor_1 = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await cursor_1.close()
        cursor_2 = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await cursor_2.close()

    async def is_fully_compactified(self,
                                    header_hash: bytes32) -> Optional[bool]:
        cursor = await self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return bool(row[0])

    async def get_first_not_compactified(self,
                                         min_height: int) -> Optional[int]:
        cursor = await self.db.execute(
            "SELECT MIN(height) from full_blocks WHERE is_fully_compactified=0 AND height>=?",
            (min_height, ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return int(row[0])
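The begin/commit/rollback helpers near the top of this class exist because block_records and the coin store share one connection and must change together; set_peak deliberately leaves its writes uncommitted for the same reason. A sketch of the intended calling pattern, with a hypothetical caller name:

    async def advance_peak(block_store: BlockStore, coin_store,
                           header_hash: bytes32) -> None:
        # One SQLite transaction covers the peak switch in both stores
        await block_store.begin_transaction()
        try:
            await block_store.set_peak(header_hash)
            # ... the matching coin_store updates for the new peak go here ...
            await block_store.commit_transaction()
        except Exception:
            await block_store.rollback_transaction()
            raise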
Example 14
class FullNodeStore:
    constants: ConsensusConstants

    # Blocks which we have created, but don't have plot signatures yet, so not yet "unfinished blocks"
    candidate_blocks: Dict[bytes32, Tuple[uint32, UnfinishedBlock]]
    candidate_backup_blocks: Dict[bytes32, Tuple[uint32, UnfinishedBlock]]

    # Header hashes of unfinished blocks that we have seen recently
    seen_unfinished_blocks: set

    # Unfinished blocks, keyed from reward hash
    unfinished_blocks: Dict[bytes32, Tuple[uint32, UnfinishedBlock,
                                           PreValidationResult]]

    # Finished slots and sps from the peak's slot onwards
    # We store all 32 SPs for each slot, starting as 32 Nones and filling them as we go
    # Also stores the total iters at the end of slot
    # For the first sub-slot, EndOfSlotBundle is None
    finished_sub_slots: List[Tuple[Optional[EndOfSubSlotBundle],
                                   List[Optional[SignagePoint]], uint128]]

    # These caches maintain objects which depend on infused blocks in the reward chain, that we
    # might receive before the blocks themselves. The dict keys are the reward chain challenge hashes.

    # End of slots which depend on infusions that we don't have
    future_eos_cache: Dict[bytes32, List[EndOfSubSlotBundle]]

    # Signage points which depend on infusions that we don't have
    future_sp_cache: Dict[bytes32, List[Tuple[uint8, SignagePoint]]]

    # Infusion point VDFs which depend on infusions that we don't have
    future_ip_cache: Dict[bytes32, List[timelord_protocol.NewInfusionPointVDF]]

    # This stores the time that each key was added to the future cache, so we can clear old keys
    future_cache_key_times: Dict[bytes32, int]

    # These recent caches are for pooling support
    recent_signage_points: LRUCache
    recent_eos: LRUCache

    # Partial hashes of unfinished blocks we are requesting
    requesting_unfinished_blocks: Set[bytes32]

    previous_generator: Optional[CompressorArg]
    pending_tx_request: Dict[bytes32, bytes32]  # tx_id: peer_id
    peers_with_tx: Dict[bytes32, Set[bytes32]]  # tx_id: Set[peer_ids]
    tx_fetch_tasks: Dict[bytes32, asyncio.Task]  # Task id: task
    serialized_wp_message: Optional[Message]
    serialized_wp_message_tip: Optional[bytes32]

    def __init__(self, constants: ConsensusConstants):
        self.candidate_blocks = {}
        self.candidate_backup_blocks = {}
        self.seen_unfinished_blocks = set()
        self.unfinished_blocks = {}
        self.finished_sub_slots = []
        self.future_eos_cache = {}
        self.future_sp_cache = {}
        self.future_ip_cache = {}
        self.recent_signage_points = LRUCache(500)
        self.recent_eos = LRUCache(50)
        self.requesting_unfinished_blocks = set()
        self.previous_generator = None
        self.future_cache_key_times = {}
        self.constants = constants
        self.clear_slots()
        self.initialize_genesis_sub_slot()
        self.pending_tx_request = {}
        self.peers_with_tx = {}
        self.tx_fetch_tasks = {}
        self.serialized_wp_message = None
        self.serialized_wp_message_tip = None

    def add_candidate_block(self,
                            quality_string: bytes32,
                            height: uint32,
                            unfinished_block: UnfinishedBlock,
                            backup: bool = False):
        if backup:
            self.candidate_backup_blocks[quality_string] = (height,
                                                            unfinished_block)
        else:
            self.candidate_blocks[quality_string] = (height, unfinished_block)

    def get_candidate_block(
            self,
            quality_string: bytes32,
            backup: bool = False) -> Optional[Tuple[uint32, UnfinishedBlock]]:
        if backup:
            return self.candidate_backup_blocks.get(quality_string, None)
        else:
            return self.candidate_blocks.get(quality_string, None)

    def clear_candidate_blocks_below(self, height: uint32) -> None:
        del_keys = []
        for key, value in self.candidate_blocks.items():
            if value[0] < height:
                del_keys.append(key)
        for key in del_keys:
            try:
                del self.candidate_blocks[key]
            except KeyError:
                pass
        del_keys = []
        for key, value in self.candidate_backup_blocks.items():
            if value[0] < height:
                del_keys.append(key)
        for key in del_keys:
            try:
                del self.candidate_backup_blocks[key]
            except KeyError:
                pass

    def seen_unfinished_block(self, object_hash: bytes32) -> bool:
        if object_hash in self.seen_unfinished_blocks:
            return True
        self.seen_unfinished_blocks.add(object_hash)
        return False

    def clear_seen_unfinished_blocks(self) -> None:
        self.seen_unfinished_blocks.clear()

    def add_unfinished_block(self, height: uint32,
                             unfinished_block: UnfinishedBlock,
                             result: PreValidationResult) -> None:
        self.unfinished_blocks[unfinished_block.partial_hash] = (
            height, unfinished_block, result)

    def get_unfinished_block(
            self,
            unfinished_reward_hash: bytes32) -> Optional[UnfinishedBlock]:
        result = self.unfinished_blocks.get(unfinished_reward_hash, None)
        if result is None:
            return None
        return result[1]

    def get_unfinished_block_result(
            self,
            unfinished_reward_hash: bytes32) -> Optional[PreValidationResult]:
        result = self.unfinished_blocks.get(unfinished_reward_hash, None)
        if result is None:
            return None
        return result[2]

    def get_unfinished_blocks(
        self
    ) -> Dict[bytes32, Tuple[uint32, UnfinishedBlock, PreValidationResult]]:
        return self.unfinished_blocks

    def clear_unfinished_blocks_below(self, height: uint32) -> None:
        del_keys: List[bytes32] = []
        for partial_reward_hash, (unf_height, unfinished_block,
                                  _) in self.unfinished_blocks.items():
            if unf_height < height:
                del_keys.append(partial_reward_hash)
        for del_key in del_keys:
            del self.unfinished_blocks[del_key]

    def remove_unfinished_block(self, partial_reward_hash: bytes32):
        if partial_reward_hash in self.unfinished_blocks:
            del self.unfinished_blocks[partial_reward_hash]

    def add_to_future_ip(
            self, infusion_point: timelord_protocol.NewInfusionPointVDF):
        ch: bytes32 = infusion_point.reward_chain_ip_vdf.challenge
        if ch not in self.future_ip_cache:
            self.future_ip_cache[ch] = []
        self.future_ip_cache[ch].append(infusion_point)

    def in_future_sp_cache(self, signage_point: SignagePoint,
                           index: uint8) -> bool:
        if signage_point.rc_vdf is None:
            return False

        if signage_point.rc_vdf.challenge not in self.future_sp_cache:
            return False
        for cache_index, cache_sp in self.future_sp_cache[
                signage_point.rc_vdf.challenge]:
            if cache_index == index and cache_sp.rc_vdf == signage_point.rc_vdf:
                return True
        return False

    def add_to_future_sp(self, signage_point: SignagePoint, index: uint8):
        # We are missing a block here
        if (signage_point.cc_vdf is None or signage_point.rc_vdf is None
                or signage_point.cc_proof is None
                or signage_point.rc_proof is None):
            return None
        if signage_point.rc_vdf.challenge not in self.future_sp_cache:
            self.future_sp_cache[signage_point.rc_vdf.challenge] = []
        if self.in_future_sp_cache(signage_point, index):
            return None

        self.future_cache_key_times[signage_point.rc_vdf.challenge] = int(
            time.time())
        self.future_sp_cache[signage_point.rc_vdf.challenge].append(
            (index, signage_point))
        log.info(
            f"Don't have rc hash {signage_point.rc_vdf.challenge}. caching signage point {index}."
        )

    def get_future_ip(
        self, rc_challenge_hash: bytes32
    ) -> List[timelord_protocol.NewInfusionPointVDF]:
        return self.future_ip_cache.get(rc_challenge_hash, [])

    def clear_old_cache_entries(self) -> None:
        current_time: int = int(time.time())
        remove_keys: List[bytes32] = []
        for rc_hash, time_added in self.future_cache_key_times.items():
            if current_time - time_added > 3600:
                remove_keys.append(rc_hash)
        for k in remove_keys:
            self.future_cache_key_times.pop(k, None)
            self.future_ip_cache.pop(k, [])
            self.future_eos_cache.pop(k, [])
            self.future_sp_cache.pop(k, [])

    def clear_slots(self):
        self.finished_sub_slots.clear()

    def get_sub_slot(
        self, challenge_hash: bytes32
    ) -> Optional[Tuple[EndOfSubSlotBundle, int, uint128]]:
        assert len(self.finished_sub_slots) >= 1
        for index, (sub_slot, _,
                    total_iters) in enumerate(self.finished_sub_slots):
            if sub_slot is not None and sub_slot.challenge_chain.get_hash(
            ) == challenge_hash:
                return sub_slot, index, total_iters
        return None

    def initialize_genesis_sub_slot(self):
        self.clear_slots()
        self.finished_sub_slots = [
            (None, [None] * self.constants.NUM_SPS_SUB_SLOT, uint128(0))
        ]

    def new_finished_sub_slot(
        self,
        eos: EndOfSubSlotBundle,
        blocks: BlockchainInterface,
        peak: Optional[BlockRecord],
        peak_full_block: Optional[FullBlock],
    ) -> Optional[List[timelord_protocol.NewInfusionPointVDF]]:
        """
        Returns None if not added. Returns a list if added. The list contains all infusion points that depended
        on this sub slot
        """
        assert len(self.finished_sub_slots) >= 1
        assert (peak is None) == (peak_full_block is None)

        last_slot, _, last_slot_iters = self.finished_sub_slots[-1]

        cc_challenge: bytes32 = (last_slot.challenge_chain.get_hash()
                                 if last_slot is not None else
                                 self.constants.GENESIS_CHALLENGE)
        rc_challenge: bytes32 = (last_slot.reward_chain.get_hash()
                                 if last_slot is not None else
                                 self.constants.GENESIS_CHALLENGE)
        icc_challenge: Optional[bytes32] = None
        icc_iters: Optional[uint64] = None

        # Skip if already present
        for slot, _, _ in self.finished_sub_slots:
            if slot == eos:
                return []

        if eos.challenge_chain.challenge_chain_end_of_slot_vdf.challenge != cc_challenge:
            # This slot does not append to our next slot
            # This prevents other peers from appending fake VDFs to our cache
            return None

        if peak is None:
            sub_slot_iters = self.constants.SUB_SLOT_ITERS_STARTING
        else:
            sub_slot_iters = peak.sub_slot_iters

        total_iters = uint128(last_slot_iters + sub_slot_iters)

        if peak is not None and peak.total_iters > last_slot_iters:
            # Peak is in this slot
            rc_challenge = eos.reward_chain.end_of_slot_vdf.challenge
            cc_start_element = peak.challenge_vdf_output
            iters = uint64(total_iters - peak.total_iters)
            if peak.reward_infusion_new_challenge != rc_challenge:
                # We don't have this challenge hash yet
                if rc_challenge not in self.future_eos_cache:
                    self.future_eos_cache[rc_challenge] = []
                self.future_eos_cache[rc_challenge].append(eos)
                self.future_cache_key_times[rc_challenge] = int(time.time())
                log.info(
                    f"Don't have challenge hash {rc_challenge}, caching EOS")
                return None

            if peak.deficit == self.constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK:
                icc_start_element = None
            elif peak.deficit == self.constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK - 1:
                icc_start_element = ClassgroupElement.get_default_element()
            else:
                icc_start_element = peak.infused_challenge_vdf_output

            if peak.deficit < self.constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK:
                curr = peak
                while not curr.first_in_sub_slot and not curr.is_challenge_block(
                        self.constants):
                    curr = blocks.block_record(curr.prev_hash)
                if curr.is_challenge_block(self.constants):
                    icc_challenge = curr.challenge_block_info_hash
                    icc_iters = uint64(total_iters - curr.total_iters)
                else:
                    assert curr.finished_infused_challenge_slot_hashes is not None
                    icc_challenge = curr.finished_infused_challenge_slot_hashes[
                        -1]
                    icc_iters = sub_slot_iters
                assert icc_challenge is not None

            if can_finish_sub_and_full_epoch(
                    self.constants,
                    blocks,
                    peak.height,
                    peak.prev_hash,
                    peak.deficit,
                    peak.sub_epoch_summary_included is not None,
            )[0]:
                assert peak_full_block is not None
                ses: Optional[SubEpochSummary] = next_sub_epoch_summary(
                    self.constants, blocks, peak.required_iters,
                    peak_full_block, True)
                if ses is not None:
                    if eos.challenge_chain.subepoch_summary_hash != ses.get_hash(
                    ):
                        log.warning(
                            f"SES not correct {ses.get_hash(), eos.challenge_chain}"
                        )
                        return None
                else:
                    if eos.challenge_chain.subepoch_summary_hash is not None:
                        log.warning("SES not correct, should be None")
                        return None
        else:
            # This is on an empty slot
            cc_start_element = ClassgroupElement.get_default_element()
            icc_start_element = ClassgroupElement.get_default_element()
            iters = sub_slot_iters
            icc_iters = sub_slot_iters

            # The icc should only be present if the previous slot had an icc too, and not deficit 0 (just finished slot)
            icc_challenge = (last_slot.infused_challenge_chain.get_hash()
                             if last_slot is not None
                             and last_slot.infused_challenge_chain is not None
                             and last_slot.reward_chain.deficit !=
                             self.constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK else
                             None)

        # Validate cc VDF
        partial_cc_vdf_info = VDFInfo(
            cc_challenge,
            iters,
            eos.challenge_chain.challenge_chain_end_of_slot_vdf.output,
        )
        # The EOS will have the whole sub-slot iters, but the proof is only the delta, from the last peak
        if eos.challenge_chain.challenge_chain_end_of_slot_vdf != dataclasses.replace(
                partial_cc_vdf_info,
                number_of_iterations=sub_slot_iters,
        ):
            return None
        if (not eos.proofs.challenge_chain_slot_proof.normalized_to_identity
                and not eos.proofs.challenge_chain_slot_proof.is_valid(
                    self.constants,
                    cc_start_element,
                    partial_cc_vdf_info,
                )):
            return None
        if (eos.proofs.challenge_chain_slot_proof.normalized_to_identity
                and not eos.proofs.challenge_chain_slot_proof.is_valid(
                    self.constants,
                    ClassgroupElement.get_default_element(),
                    eos.challenge_chain.challenge_chain_end_of_slot_vdf,
                )):
            return None

        # Validate reward chain VDF
        if not eos.proofs.reward_chain_slot_proof.is_valid(
                self.constants,
                ClassgroupElement.get_default_element(),
                eos.reward_chain.end_of_slot_vdf,
                VDFInfo(rc_challenge, iters,
                        eos.reward_chain.end_of_slot_vdf.output),
        ):
            return None

        if icc_challenge is not None:
            assert icc_start_element is not None
            assert icc_iters is not None
            assert eos.infused_challenge_chain is not None
            assert eos.infused_challenge_chain is not None
            assert eos.proofs.infused_challenge_chain_slot_proof is not None

            partial_icc_vdf_info = VDFInfo(
                icc_challenge,
                iters,
                eos.infused_challenge_chain.
                infused_challenge_chain_end_of_slot_vdf.output,
            )
            # The EOS will have the whole sub-slot iters, but the proof is only the delta, from the last peak
            if eos.infused_challenge_chain.infused_challenge_chain_end_of_slot_vdf != dataclasses.replace(
                    partial_icc_vdf_info,
                    number_of_iterations=icc_iters,
            ):
                return None
            if (not eos.proofs.infused_challenge_chain_slot_proof.
                    normalized_to_identity and
                    not eos.proofs.infused_challenge_chain_slot_proof.is_valid(
                        self.constants, icc_start_element,
                        partial_icc_vdf_info)):
                return None
            if (eos.proofs.infused_challenge_chain_slot_proof.
                    normalized_to_identity and
                    not eos.proofs.infused_challenge_chain_slot_proof.is_valid(
                        self.constants,
                        ClassgroupElement.get_default_element(),
                        eos.infused_challenge_chain.
                        infused_challenge_chain_end_of_slot_vdf,
                    )):
                return None
        else:
            # This is the first sub slot and it's empty, therefore there is no ICC
            if eos.infused_challenge_chain is not None or eos.proofs.infused_challenge_chain_slot_proof is not None:
                return None

        self.finished_sub_slots.append(
            (eos, [None] * self.constants.NUM_SPS_SUB_SLOT, total_iters))

        new_cc_hash = eos.challenge_chain.get_hash()
        self.recent_eos.put(new_cc_hash, (eos, time.time()))

        new_ips: List[timelord_protocol.NewInfusionPointVDF] = []
        for ip in self.future_ip_cache.get(eos.reward_chain.get_hash(), []):
            new_ips.append(ip)

        return new_ips

    def new_signage_point(
        self,
        index: uint8,
        blocks: BlockchainInterface,
        peak: Optional[BlockRecord],
        next_sub_slot_iters: uint64,
        signage_point: SignagePoint,
        skip_vdf_validation=False,
    ) -> bool:
        """
        Returns true if sp successfully added
        """
        assert len(self.finished_sub_slots) >= 1

        if peak is None or peak.height < 2:
            sub_slot_iters = self.constants.SUB_SLOT_ITERS_STARTING
        else:
            sub_slot_iters = peak.sub_slot_iters

        # If we don't have this slot, return False
        if index == 0 or index >= self.constants.NUM_SPS_SUB_SLOT:
            return False
        assert (signage_point.cc_vdf is not None
                and signage_point.cc_proof is not None
                and signage_point.rc_vdf is not None
                and signage_point.rc_proof is not None)
        for sub_slot, sp_arr, start_ss_total_iters in self.finished_sub_slots:
            if sub_slot is None:
                assert start_ss_total_iters == 0
                ss_challenge_hash = self.constants.GENESIS_CHALLENGE
                ss_reward_hash = self.constants.GENESIS_CHALLENGE
            else:
                ss_challenge_hash = sub_slot.challenge_chain.get_hash()
                ss_reward_hash = sub_slot.reward_chain.get_hash()
            if ss_challenge_hash == signage_point.cc_vdf.challenge:
                # If we do have this slot, find the Prev block from SP and validate SP
                if peak is not None and start_ss_total_iters > peak.total_iters:
                    # We are in a future sub slot from the peak, so maybe there is a new SSI
                    checkpoint_size: uint64 = uint64(
                        next_sub_slot_iters // self.constants.NUM_SPS_SUB_SLOT)
                    delta_iters: uint64 = uint64(checkpoint_size * index)
                    future_sub_slot: bool = True
                else:
                    # We are not in a future sub slot from the peak, so there is no new SSI
                    checkpoint_size = uint64(sub_slot_iters //
                                             self.constants.NUM_SPS_SUB_SLOT)
                    delta_iters = uint64(checkpoint_size * index)
                    future_sub_slot = False
                sp_total_iters = start_ss_total_iters + delta_iters

                curr = peak
                if peak is None or future_sub_slot:
                    check_from_start_of_ss = True
                else:
                    check_from_start_of_ss = False
                    while (curr is not None
                           and curr.total_iters > start_ss_total_iters
                           and curr.total_iters > sp_total_iters):
                        if curr.first_in_sub_slot:
                            # Did not find a block whose iters are before our sp_total_iters in this sub slot
                            check_from_start_of_ss = True
                            break
                        curr = blocks.block_record(curr.prev_hash)

                if check_from_start_of_ss:
                    # Check VDFs from start of sub slot
                    cc_vdf_info_expected = VDFInfo(
                        ss_challenge_hash,
                        delta_iters,
                        signage_point.cc_vdf.output,
                    )

                    rc_vdf_info_expected = VDFInfo(
                        ss_reward_hash,
                        delta_iters,
                        signage_point.rc_vdf.output,
                    )
                else:
                    # Check VDFs from curr
                    assert curr is not None
                    cc_vdf_info_expected = VDFInfo(
                        ss_challenge_hash,
                        uint64(sp_total_iters - curr.total_iters),
                        signage_point.cc_vdf.output,
                    )
                    rc_vdf_info_expected = VDFInfo(
                        curr.reward_infusion_new_challenge,
                        uint64(sp_total_iters - curr.total_iters),
                        signage_point.rc_vdf.output,
                    )
                if signage_point.cc_vdf != dataclasses.replace(
                        cc_vdf_info_expected,
                        number_of_iterations=delta_iters):
                    self.add_to_future_sp(signage_point, index)
                    return False
                if check_from_start_of_ss:
                    start_ele = ClassgroupElement.get_default_element()
                else:
                    assert curr is not None
                    start_ele = curr.challenge_vdf_output
                if not skip_vdf_validation:
                    if not signage_point.cc_proof.normalized_to_identity and not signage_point.cc_proof.is_valid(
                            self.constants,
                            start_ele,
                            cc_vdf_info_expected,
                    ):
                        self.add_to_future_sp(signage_point, index)
                        return False
                    if signage_point.cc_proof.normalized_to_identity and not signage_point.cc_proof.is_valid(
                            self.constants,
                            ClassgroupElement.get_default_element(),
                            signage_point.cc_vdf,
                    ):
                        self.add_to_future_sp(signage_point, index)
                        return False

                if rc_vdf_info_expected.challenge != signage_point.rc_vdf.challenge:
                    # This signage point is probably outdated
                    self.add_to_future_sp(signage_point, index)
                    return False

                if not skip_vdf_validation:
                    if not signage_point.rc_proof.is_valid(
                            self.constants,
                            ClassgroupElement.get_default_element(),
                            signage_point.rc_vdf,
                            rc_vdf_info_expected,
                    ):
                        self.add_to_future_sp(signage_point, index)
                        return False

                sp_arr[index] = signage_point
                self.recent_signage_points.put(
                    signage_point.cc_vdf.output.get_hash(),
                    (signage_point, time.time()))
                return True
        self.add_to_future_sp(signage_point, index)
        return False
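
    # Worked example of the checkpoint arithmetic above, with assumed
    # (illustrative, not authoritative) constants: if NUM_SPS_SUB_SLOT = 64
    # and sub_slot_iters = 134_217_728, then checkpoint_size =
    # 134_217_728 // 64 = 2_097_152 iters between signage points, so signage
    # point index 5 sits delta_iters = 5 * 2_097_152 = 10_485_760 iters past
    # the start of its sub-slot, and
    # sp_total_iters = start_ss_total_iters + 10_485_760.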

    def get_signage_point(self,
                          cc_signage_point: bytes32) -> Optional[SignagePoint]:
        assert len(self.finished_sub_slots) >= 1
        if cc_signage_point == self.constants.GENESIS_CHALLENGE:
            return SignagePoint(None, None, None, None)

        for sub_slot, sps, _ in self.finished_sub_slots:
            if sub_slot is not None and sub_slot.challenge_chain.get_hash() == cc_signage_point:
                return SignagePoint(None, None, None, None)
            for sp in sps:
                if sp is not None:
                    assert sp.cc_vdf is not None
                    if sp.cc_vdf.output.get_hash() == cc_signage_point:
                        return sp
        return None

    def get_signage_point_by_index(
            self, challenge_hash: bytes32, index: uint8,
            last_rc_infusion: bytes32) -> Optional[SignagePoint]:
        assert len(self.finished_sub_slots) >= 1
        for sub_slot, sps, _ in self.finished_sub_slots:
            if sub_slot is not None:
                cc_hash = sub_slot.challenge_chain.get_hash()
            else:
                cc_hash = self.constants.GENESIS_CHALLENGE

            if cc_hash == challenge_hash:
                if index == 0:
                    return SignagePoint(None, None, None, None)
                sp: Optional[SignagePoint] = sps[index]
                if sp is not None:
                    assert sp.rc_vdf is not None
                    if sp.rc_vdf.challenge == last_rc_infusion:
                        return sp
                return None
        return None

    def have_newer_signage_point(self, challenge_hash: bytes32, index: uint8,
                                 last_rc_infusion: bytes32) -> bool:
        """
        Returns true if we have a signage point at this index which is based on a newer infusion.
        """
        assert len(self.finished_sub_slots) >= 1
        for sub_slot, sps, _ in self.finished_sub_slots:
            if sub_slot is not None:
                cc_hash = sub_slot.challenge_chain.get_hash()
            else:
                cc_hash = self.constants.GENESIS_CHALLENGE

            if cc_hash == challenge_hash:
                found_rc_hash = False
                for i in range(0, index):
                    sp: Optional[SignagePoint] = sps[i]
                    if sp is not None and sp.rc_vdf is not None and sp.rc_vdf.challenge == last_rc_infusion:
                        found_rc_hash = True
                sp = sps[index]
                if (found_rc_hash and sp is not None and sp.rc_vdf is not None
                        and sp.rc_vdf.challenge != last_rc_infusion):
                    return True
        return False

    def new_peak(
        self,
        peak: BlockRecord,
        peak_full_block: FullBlock,
        sp_sub_slot: Optional[
            EndOfSubSlotBundle],  # None if not overflow, or in first/second slot
        ip_sub_slot: Optional[EndOfSubSlotBundle],  # None if in first slot
        fork_block: Optional[BlockRecord],
        blocks: BlockchainInterface,
    ) -> Tuple[Optional[EndOfSubSlotBundle], List[Tuple[uint8, SignagePoint]],
               List[timelord_protocol.NewInfusionPointVDF]]:
        """
        If the peak is an overflow block, must provide two sub-slots: one for the current sub-slot and one for
        the prev sub-slot (since we still might get more blocks with an sp in the previous sub-slot)

        Results in either one or two sub-slots in finished_sub_slots.
        """
        assert len(self.finished_sub_slots) >= 1

        if ip_sub_slot is None:
            # We are still in the first sub-slot, so there are no new sub-slots yet
            self.initialize_genesis_sub_slot()
        else:
            # This is not the first sub-slot in the chain
            sp_sub_slot_sps: List[Optional[SignagePoint]] = [None] * self.constants.NUM_SPS_SUB_SLOT
            ip_sub_slot_sps: List[Optional[SignagePoint]] = [None] * self.constants.NUM_SPS_SUB_SLOT

            if fork_block is not None and fork_block.sub_slot_iters != peak.sub_slot_iters:
                # If there was a reorg and a difficulty adjustment, just clear all the slots
                self.clear_slots()
            else:
                interval_iters = calculate_sp_interval_iters(
                    self.constants, peak.sub_slot_iters)
                # If it's not a reorg, or there is a reorg on the same difficulty, we can keep signage points
                # that we had before, in the cache
                for index, (sub_slot, sps,
                            total_iters) in enumerate(self.finished_sub_slots):
                    if sub_slot is None:
                        continue

                    if fork_block is None:
                        # If this is not a reorg, we still want to remove signage points after the new peak
                        fork_block = peak
                    replaced_sps: List[Optional[SignagePoint]] = []  # index 0 is the end of sub slot
                    for i, sp in enumerate(sps):
                        if (total_iters +
                                i * interval_iters) < fork_block.total_iters:
                            # SPs before the fork point are still valid
                            replaced_sps.append(sp)
                        else:
                            if sp is not None:
                                log.debug(
                                    f"Reverting {i} {(total_iters + i * interval_iters)} {fork_block.total_iters}"
                                )
                            # SPs after the fork point should be removed
                            replaced_sps.append(None)
                    assert len(sps) == len(replaced_sps)

                    if sub_slot == sp_sub_slot:
                        sp_sub_slot_sps = replaced_sps
                    if sub_slot == ip_sub_slot:
                        ip_sub_slot_sps = replaced_sps

            self.clear_slots()

            prev_sub_slot_total_iters = peak.sp_sub_slot_total_iters(
                self.constants)
            if sp_sub_slot is not None or prev_sub_slot_total_iters == 0:
                assert peak.overflow or prev_sub_slot_total_iters
                self.finished_sub_slots.append(
                    (sp_sub_slot, sp_sub_slot_sps, prev_sub_slot_total_iters))

            ip_sub_slot_total_iters = peak.ip_sub_slot_total_iters(
                self.constants)
            self.finished_sub_slots.append(
                (ip_sub_slot, ip_sub_slot_sps, ip_sub_slot_total_iters))

        new_eos: Optional[EndOfSubSlotBundle] = None
        new_sps: List[Tuple[uint8, SignagePoint]] = []
        new_ips: List[timelord_protocol.NewInfusionPointVDF] = []

        future_eos: List[EndOfSubSlotBundle] = self.future_eos_cache.get(
            peak.reward_infusion_new_challenge, []).copy()
        for eos in future_eos:
            if self.new_finished_sub_slot(eos, blocks, peak,
                                          peak_full_block) is not None:
                new_eos = eos
                break

        future_sps: List[Tuple[uint8, SignagePoint]] = self.future_sp_cache.get(
            peak.reward_infusion_new_challenge, []
        ).copy()
        for index, sp in future_sps:
            assert sp.cc_vdf is not None
            if self.new_signage_point(index, blocks, peak, peak.sub_slot_iters,
                                      sp):
                new_sps.append((index, sp))

        for ip in self.future_ip_cache.get(peak.reward_infusion_new_challenge,
                                           []):
            new_ips.append(ip)

        self.future_eos_cache.pop(peak.reward_infusion_new_challenge, [])
        self.future_sp_cache.pop(peak.reward_infusion_new_challenge, [])
        self.future_ip_cache.pop(peak.reward_infusion_new_challenge, [])

        for eos_op, _, _ in self.finished_sub_slots:
            if eos_op is not None:
                self.recent_eos.put(eos_op.challenge_chain.get_hash(),
                                    (eos_op, time.time()))

        return new_eos, new_sps, new_ips
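
    # Illustrative trace of the docstring above (derived from the append
    # calls in the method): for a non-overflow peak only
    # (ip_sub_slot, sps, iters) is appended, leaving one entry in
    # finished_sub_slots; for an overflow peak the signage points live in
    # the previous sub-slot, so both (sp_sub_slot, ...) and
    # (ip_sub_slot, ...) are appended, leaving two entries.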

    def get_finished_sub_slots(
        self,
        block_records: BlockchainInterface,
        prev_b: Optional[BlockRecord],
        last_challenge_to_add: bytes32,
    ) -> Optional[List[EndOfSubSlotBundle]]:
        """
        Retrieves the EndOfSubSlotBundles that are in the store either:
        1. From the starting challenge if prev_b is None
        2. That are not included in the blockchain with peak of prev_b if prev_b is not None

        Stops at last_challenge_to_add.
        """

        if prev_b is None:
            # The first sub slot must be None
            assert self.finished_sub_slots[0][0] is None
            challenge_in_chain: bytes32 = self.constants.GENESIS_CHALLENGE
        else:
            curr: BlockRecord = prev_b
            while not curr.first_in_sub_slot:
                curr = block_records.block_record(curr.prev_hash)
            assert curr is not None
            assert curr.finished_challenge_slot_hashes is not None
            challenge_in_chain = curr.finished_challenge_slot_hashes[-1]

        if last_challenge_to_add == challenge_in_chain:
            # No additional slots to add
            return []

        collected_sub_slots: List[EndOfSubSlotBundle] = []
        found_last_challenge = False
        found_connecting_challenge = False
        for sub_slot, sps, total_iters in self.finished_sub_slots[1:]:
            assert sub_slot is not None
            if sub_slot.challenge_chain.challenge_chain_end_of_slot_vdf.challenge == challenge_in_chain:
                found_connecting_challenge = True
            if found_connecting_challenge:
                collected_sub_slots.append(sub_slot)
            if found_connecting_challenge and sub_slot.challenge_chain.get_hash() == last_challenge_to_add:
                found_last_challenge = True
                break
        if not found_last_challenge:
            log.warning(
                f"Did not find hash {last_challenge_to_add} connected to "
                f"{challenge_in_chain}")
            return None
        return collected_sub_slots
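
The store above parks messages that arrive before their prerequisite peak in
future_eos_cache / future_sp_cache / future_ip_cache and replays them in
new_peak once the matching reward-infusion challenge is on chain. Below is a
minimal sketch of that park-and-replay pattern, with illustrative names only
(the real caches also bound their size):

from collections import defaultdict
from typing import Any, Dict, List

class FutureCache:
    """Parks items keyed by the challenge they depend on; peak-update code
    replays and evicts them once that challenge appears on chain."""

    def __init__(self) -> None:
        self._store: Dict[bytes, List[Any]] = defaultdict(list)

    def add(self, challenge: bytes, item: Any) -> None:
        self._store[challenge].append(item)

    def get(self, challenge: bytes, default: List[Any]) -> List[Any]:
        return self._store.get(challenge, default)

    def pop(self, challenge: bytes, default: List[Any]) -> List[Any]:
        return self._store.pop(challenge, default)

# Park a signage point that arrived early, then replay it on the new peak:
cache = FutureCache()
cache.add(b"challenge-1", ("sp-index-5", "signage-point"))
for item in cache.pop(b"challenge-1", []):
    pass  # re-validate here, e.g. with new_signage_point(...)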
Example no. 15
class CoinStore:
    """
    This object handles CoinRecords in DB.
    A cache is maintained for quicker access to recent coins.
    """

    coin_record_db: aiosqlite.Connection
    coin_record_cache: LRUCache
    cache_size: uint32
    db_wrapper: DBWrapper

    @classmethod
    async def create(cls, db_wrapper: DBWrapper, cache_size: uint32 = uint32(60000)):
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db
        await self.coin_record_db.execute("pragma journal_mode=wal")
        await self.coin_record_db.execute("pragma synchronous=2")
        await self.coin_record_db.execute(
            (
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name text PRIMARY KEY,"
                " confirmed_index bigint,"
                " spent_index bigint,"
                " spent int,"
                " coinbase int,"
                " puzzle_hash text,"
                " coin_parent text,"
                " amount blob,"
                " timestamp bigint)"
            )
        )

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)")

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent on coin_record(spent)")

        await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)")

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self

    async def new_block(self, block: FullBlock, tx_additions: List[Coin], tx_removals: List[bytes32]):
        """
        Only called for transaction blocks (which thus have rewards and transactions)
        """
        if block.is_transaction_block() is False:
            return None
        assert block.foliage_transaction_block is not None

        for coin in tx_additions:
            record: CoinRecord = CoinRecord(
                coin,
                block.height,
                uint32(0),
                False,
                False,
                block.foliage_transaction_block.timestamp,
            )
            await self._add_coin_record(record, False)

        included_reward_coins = block.get_included_reward_coins()
        if block.height == 0:
            assert len(included_reward_coins) == 0
        else:
            assert len(included_reward_coins) >= 2

        for coin in included_reward_coins:
            reward_coin_r: CoinRecord = CoinRecord(
                coin,
                block.height,
                uint32(0),
                False,
                True,
                block.foliage_transaction_block.timestamp,
            )
            await self._add_coin_record(reward_coin_r, False)

        total_amount_spent: int = 0
        for coin_name in tx_removals:
            total_amount_spent += await self._set_spent(coin_name, block.height)

        # Sanity check, already checked in block_body_validation
        assert sum([a.amount for a in tx_additions]) <= total_amount_spent

    # Checks DB and DiffStores for CoinRecord with coin_name and returns it
    async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]:
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            record = CoinRecord(coin, row[1], row[2], row[3], row[4], row[8])
            self.coin_record_cache.put(record.coin.name(), record)
            return record
        return None

    async def get_coins_added_at_height(self, height: uint32) -> List[CoinRecord]:
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE confirmed_index=?", (height,))
        rows = await cursor.fetchall()
        await cursor.close()
        coins = []
        for row in rows:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            coins.append(CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]))
        return coins

    async def get_coins_removed_at_height(self, height: uint32) -> List[CoinRecord]:
        cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE spent_index=?", (height,))
        rows = await cursor.fetchall()
        await cursor.close()
        coins = []
        for row in rows:
            spent: bool = bool(row[3])
            if spent:
                coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
                coin_record = CoinRecord(coin, row[1], row[2], spent, row[4], row[8])
                coins.append(coin_record)
        return coins

    # Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
    async def get_coin_records_by_puzzle_hash(
        self,
        include_spent_coins: bool,
        puzzle_hash: bytes32,
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:

        coins = set()
        cursor = await self.coin_record_db.execute(
            f"SELECT * from coin_record INDEXED BY coin_puzzle_hash WHERE puzzle_hash=? "
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            (puzzle_hash.hex(), start_height, end_height),
        )
        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]))
        return list(coins)

    async def get_coin_records_by_puzzle_hashes(
        self,
        include_spent_coins: bool,
        puzzle_hashes: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        cursor = await self.coin_record_db.execute(
            f"SELECT * from coin_record INDEXED BY coin_puzzle_hash "
            f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            puzzle_hashes_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]))
        return list(coins)

    async def get_coin_records_by_names(
        self,
        include_spent_coins: bool,
        names: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        if len(names) == 0:
            return []

        coins = set()
        names_db = tuple([name.hex() for name in names])
        cursor = await self.coin_record_db.execute(
            f'SELECT * from coin_record WHERE coin_name in ({"?," * (len(names_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            names_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]))
        return list(coins)

    async def get_coin_records_by_parent_ids(
        self,
        include_spent_coins: bool,
        parent_ids: List[bytes32],
        start_height: uint32 = uint32(0),
        end_height: uint32 = uint32((2 ** 32) - 1),
    ) -> List[CoinRecord]:
        if len(parent_ids) == 0:
            return []

        coins = set()
        parent_ids_db = tuple([pid.hex() for pid in parent_ids])
        cursor = await self.coin_record_db.execute(
            f'SELECT * from coin_record WHERE coin_parent in ({"?," * (len(parent_ids_db) - 1)}?) '
            f"AND confirmed_index>=? AND confirmed_index<? "
            f"{'' if include_spent_coins else 'AND spent=0'}",
            parent_ids_db + (start_height, end_height),
        )

        rows = await cursor.fetchall()

        await cursor.close()
        for row in rows:
            coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
            coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]))
        return list(coins)

    async def rollback_to_block(self, block_index: int):
        """
        Note that block_index can be negative, in which case everything is rolled back
        """
        # Update memory cache
        delete_queue: List[bytes32] = []
        for coin_name, coin_record in list(self.coin_record_cache.cache.items()):
            if int(coin_record.spent_block_index) > block_index:
                new_record = CoinRecord(
                    coin_record.coin,
                    coin_record.confirmed_block_index,
                    uint32(0),
                    False,
                    coin_record.coinbase,
                    coin_record.timestamp,
                )
                self.coin_record_cache.put(coin_record.coin.name(), new_record)
            if int(coin_record.confirmed_block_index) > block_index:
                delete_queue.append(coin_name)

        for coin_name in delete_queue:
            self.coin_record_cache.remove(coin_name)

        # Delete from storage
        c1 = await self.coin_record_db.execute("DELETE FROM coin_record WHERE confirmed_index>?", (block_index,))
        await c1.close()
        c2 = await self.coin_record_db.execute(
            "UPDATE coin_record SET spent_index = 0, spent = 0 WHERE spent_index>?",
            (block_index,),
        )
        await c2.close()

    # Store CoinRecord in DB and ram cache
    async def _add_coin_record(self, record: CoinRecord, allow_replace: bool) -> None:
        if self.coin_record_cache.get(record.coin.name()) is not None:
            self.coin_record_cache.remove(record.coin.name())

        cursor = await self.coin_record_db.execute(
            f"INSERT {'OR REPLACE ' if allow_replace else ''}INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                record.coin.name().hex(),
                record.confirmed_block_index,
                record.spent_block_index,
                int(record.spent),
                int(record.coinbase),
                str(record.coin.puzzle_hash.hex()),
                str(record.coin.parent_coin_info.hex()),
                bytes(record.coin.amount),
                record.timestamp,
            ),
        )
        await cursor.close()

    # Update coin_record to be spent in DB
    async def _set_spent(self, coin_name: bytes32, index: uint32) -> uint64:
        current: Optional[CoinRecord] = await self.get_coin_record(coin_name)
        if current is None:
            raise ValueError(f"Cannot spend a coin that does not exist in db: {coin_name}")

        assert not current.spent  # Redundant sanity check, already checked in block_body_validation
        spent: CoinRecord = CoinRecord(
            current.coin,
            current.confirmed_block_index,
            index,
            True,
            current.coinbase,
            current.timestamp,
        )  # type: ignore # noqa
        await self._add_coin_record(spent, True)
        return current.coin.amount
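
Every store in these examples depends on an LRUCache exposing get/put/remove
and a .cache mapping (rollback_to_block iterates .cache.items(), and the
BlockStore below expects remove to raise KeyError on a miss). Here is a
minimal OrderedDict-based sketch of that interface, offered as an assumption
for illustration rather than the actual implementation:

from collections import OrderedDict
from typing import Any, Optional

class LRUCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache: "OrderedDict[Any, Any]" = OrderedDict()

    def get(self, key: Any) -> Optional[Any]:
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)  # mark as most recently used
        return self.cache[key]

    def put(self, key: Any, value: Any) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # evict the least recently used

    def remove(self, key: Any) -> None:
        del self.cache[key]  # raises KeyError if absent, as callers expect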
Example no. 16
    def __init__(self, private_key: PrivateKey, config: Dict,
                 constants: ConsensusConstants):
        self.log = logging
        # If you want to log to a file: use filename='example.log', encoding='utf-8'
        self.log.basicConfig(level=logging.INFO)

        self.private_key = private_key
        self.public_key: G1Element = private_key.get_g1()
        self.config = config
        self.constants = constants
        self.node_rpc_client = None
        self.wallet_rpc_client = None

        self.store: Optional[PoolStore] = None

        self.pool_fee = 0.01

        # This number should be held constant and be consistent for every pool in the network. DO NOT CHANGE
        self.iters_limit = self.constants.POOL_SUB_SLOT_ITERS // 64

        # This number should not be changed, since users will put this into their singletons
        self.relative_lock_height = uint32(100)

        # TODO(pool): potentially tweak these numbers for security and performance
        self.pool_url = "https://myreferencepool.com"
        self.min_difficulty = uint64(10)  # 10 difficulty is about 1 proof a day per plot
        self.default_difficulty: uint64 = uint64(10)

        self.pending_point_partials: Optional[asyncio.Queue] = None
        self.recent_points_added: LRUCache = LRUCache(20000)

        # This is where the block rewards will get paid out to. The pool needs to support this address forever,
        # since the farmers will encode it into their singleton on the blockchain.

        self.default_pool_puzzle_hash: bytes32 = bytes32(
            decode_puzzle_hash(
                "xch12ma5m7sezasgh95wkyr8470ngryec27jxcvxcmsmc4ghy7c4njssnn623q"
            ))

        # The pool fees will be sent to this address
        self.pool_fee_puzzle_hash: bytes32 = bytes32(
            decode_puzzle_hash(
                "txch1h8ggpvqzhrquuchquk7s970cy0m0e0yxd4hxqwzqkpzxk9jx9nzqmd67ux"
            ))

        # This is the wallet fingerprint and ID for the wallet spending the funds from `self.default_pool_puzzle_hash`
        self.wallet_fingerprint = 2938470744
        self.wallet_id = "1"

        # We need to check for slow farmers. If farmers cannot submit proofs in time, they won't be able to win
        # any rewards either. This number can be tweaked to be more or less strict. More strict ensures everyone
        # gets high rewards, but it might cause some of the slower farmers to not be able to participate in the pool.
        self.partial_time_limit: int = 25

        # There is always a risk of a reorg, in which case we cannot reward farmers that submitted partials in that
        # reorg. That is why we have a time delay before changing any account points.
        self.partial_confirmation_delay: int = 30

        # Keeps track of when each farmer last changed their difficulty, to rate limit how often they can change it
        # This helps when the farmer is farming from two machines at the same time (with conflicting difficulties)
        self.difficulty_change_time: Dict[bytes32, uint64] = {}

        # These are the puzzle hashes that we look for on chain, which we can claim to our pool
        self.scan_p2_singleton_puzzle_hashes: Set[bytes32] = set()

        # Don't scan anything before this height, for efficiency (for example pool start date)
        self.scan_start_height: uint32 = uint32(1000)

        # Interval for scanning and collecting the pool rewards
        self.collect_pool_rewards_interval = 600

        # After this many confirmations, a transaction is considered final and irreversible
        self.confirmation_security_threshold = 6

        # Interval for making payout transactions to farmers
        self.payment_interval = 600

        # We will not make transactions with more targets than this, to ensure our transaction gets into the blockchain
        # faster.
        self.max_additions_per_transaction = 400

        # This is the list of payments that we have not sent yet, to farmers
        self.pending_payments: Optional[asyncio.Queue] = None

        # Keeps track of the latest state of our node
        self.blockchain_state = {"peak": None}

        # Whether or not the wallet is synced (required to make payments)
        self.wallet_synced = False

        # Tasks (infinite while loops) for different purposes
        self.confirm_partials_loop_task: Optional[asyncio.Task] = None
        self.collect_pool_rewards_loop_task: Optional[asyncio.Task] = None
        self.create_payment_loop_task: Optional[asyncio.Task] = None
        self.submit_payment_loop_task: Optional[asyncio.Task] = None
        self.get_peak_loop_task: Optional[asyncio.Task] = None

        self.node_rpc_client: Optional[FullNodeRpcClient] = None
        self.wallet_rpc_client: Optional[WalletRpcClient] = None
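
The iters_limit above caps how late a partial proof may arrive: dividing the
pool's sub-slot iterations by 64 yields one signage-point interval. A worked
example with an assumed constant (the real POOL_SUB_SLOT_ITERS may differ):

POOL_SUB_SLOT_ITERS = 37_600_000_000  # assumed value, for illustration only
iters_limit = POOL_SUB_SLOT_ITERS // 64
print(iters_limit)  # 587_500_000 iters, roughly one signage-point interval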
Example no. 17
class CoinStore:
    """
    This object handles CoinRecords in DB.
    A cache is maintained for quicker access to recent coins.
    """

    coin_record_db: aiosqlite.Connection
    coin_record_cache: LRUCache
    cache_size: uint32
    db_wrapper: DBWrapper

    @classmethod
    async def create(cls,
                     db_wrapper: DBWrapper,
                     cache_size: uint32 = uint32(60000)):
        self = cls()

        self.cache_size = cache_size
        self.db_wrapper = db_wrapper
        self.coin_record_db = db_wrapper.db

        if self.db_wrapper.db_version == 2:

            # the coin_name is unique in this table because the CoinStore always
            # only represents a single peak
            await self.coin_record_db.execute(
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name blob PRIMARY KEY,"
                " confirmed_index bigint,"
                " spent_index bigint,"  # if this is zero, it means the coin has not been spent
                " coinbase int,"
                " puzzle_hash blob,"
                " coin_parent blob,"
                " amount blob,"  # we use a blob of 8 bytes to store uint64
                " timestamp bigint)")

        else:

            # the coin_name is unique in this table because the CoinStore always
            # only represents a single peak
            await self.coin_record_db.execute(
                ("CREATE TABLE IF NOT EXISTS coin_record("
                 "coin_name text PRIMARY KEY,"
                 " confirmed_index bigint,"
                 " spent_index bigint,"
                 " spent int,"
                 " coinbase int,"
                 " puzzle_hash text,"
                 " coin_parent text,"
                 " amount blob,"
                 " timestamp bigint)"))

        # Useful for reorg lookups
        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_confirmed_index on coin_record(confirmed_index)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_spent_index on coin_record(spent_index)"
        )

        if self.db_wrapper.allow_upgrades:
            await self.coin_record_db.execute("DROP INDEX IF EXISTS coin_spent"
                                              )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)"
        )

        await self.coin_record_db.execute(
            "CREATE INDEX IF NOT EXISTS coin_parent_index on coin_record(coin_parent)"
        )

        await self.coin_record_db.commit()
        self.coin_record_cache = LRUCache(cache_size)
        return self

    def maybe_from_hex(self, field: Any) -> bytes:
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return bytes.fromhex(field)

    def maybe_to_hex(self, field: bytes) -> Any:
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return field.hex()
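
    # Illustrative behaviour of the two helpers above: under db_version == 2,
    # keys travel as raw 32-byte blobs; under v1, as 64-character hex text.
    # For example, with coin_name = bytes32(b"\x11" * 32):
    #   v2: maybe_to_hex(coin_name) -> the same 32 raw bytes, unchanged
    #   v1: maybe_to_hex(coin_name) -> "1111...1111" (64 hex characters)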

    async def new_block(
        self,
        height: uint32,
        timestamp: uint64,
        included_reward_coins: Set[Coin],
        tx_additions: List[Coin],
        tx_removals: List[bytes32],
    ) -> List[CoinRecord]:
        """
        Only called for transaction blocks (which thus have rewards and transactions)
        Returns a list of the CoinRecords that were added by this block
        """

        start = time()

        additions = []

        for coin in tx_additions:
            record: CoinRecord = CoinRecord(
                coin,
                height,
                uint32(0),
                False,
                timestamp,
            )
            additions.append(record)

        if height == 0:
            assert len(included_reward_coins) == 0
        else:
            assert len(included_reward_coins) >= 2

        for coin in included_reward_coins:
            reward_coin_r: CoinRecord = CoinRecord(
                coin,
                height,
                uint32(0),
                True,
                timestamp,
            )
            additions.append(reward_coin_r)

        await self._add_coin_records(additions)
        await self._set_spent(tx_removals, height)

        end = time()
        log.log(
            logging.WARNING if end - start > 10 else logging.DEBUG,
            f"It took {end - start:0.2f}s to apply {len(tx_additions)} additions and "
            f"{len(tx_removals)} removals to the coin store. Make sure the "
            "blockchain database is on a fast drive",
        )

        return additions

    # Checks DB and DiffStores for CoinRecord with coin_name and returns it
    async def get_coin_record(self,
                              coin_name: bytes32) -> Optional[CoinRecord]:
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached

        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE coin_name=?",
            (self.maybe_to_hex(coin_name), ),
        ) as cursor:
            row = await cursor.fetchone()
            if row is not None:
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, row[0], row[1], row[2], row[6])
                self.coin_record_cache.put(record.coin.name(), record)
                return record
        return None

    async def get_coins_added_at_height(self,
                                        height: uint32) -> List[CoinRecord]:
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE confirmed_index=?",
            (height, ),
        ) as cursor:
            rows = await cursor.fetchall()
            coins = []
            for row in rows:
                coin = self.row_to_coin(row)
                coins.append(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return coins

    async def get_coins_removed_at_height(self,
                                          height: uint32) -> List[CoinRecord]:
        # Special case to avoid querying all unspent coins (spent_index=0)
        if height == 0:
            return []
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE spent_index=?",
            (height, ),
        ) as cursor:
            coins = []
            for row in await cursor.fetchall():
                if row[1] != 0:
                    coin = self.row_to_coin(row)
                    coin_record = CoinRecord(coin, row[0], row[1], row[2],
                                             row[6])
                    coins.append(coin_record)
            return coins

    # Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
    async def get_coin_records_by_puzzle_hash(
            self,
            include_spent_coins: bool,
            puzzle_hash: bytes32,
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:

        coins = set()

        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash WHERE puzzle_hash=? "
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
            (self.maybe_to_hex(puzzle_hash), start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_records_by_puzzle_hashes(
            self,
            include_spent_coins: bool,
            puzzle_hashes: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            puzzle_hashes_db = tuple(puzzle_hashes)
        else:
            puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
                f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                puzzle_hashes_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_records_by_names(
            self,
            include_spent_coins: bool,
            names: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        if len(names) == 0:
            return []

        coins = set()
        names_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            names_db = tuple(names)
        else:
            names_db = tuple([name.hex() for name in names])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_name in ({"?," * (len(names) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                names_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))

        return list(coins)

    def row_to_coin(self, row) -> Coin:
        return Coin(bytes32(self.maybe_from_hex(row[4])),
                    bytes32(self.maybe_from_hex(row[3])),
                    uint64.from_bytes(row[5]))

    def row_to_coin_state(self, row):
        coin = self.row_to_coin(row)
        spent_h = None
        if row[1] != 0:
            spent_h = row[1]
        return CoinState(coin, spent_h, row[0])
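
    # Column layout assumed by row_to_coin / row_to_coin_state and by every
    # row[N] access in this class (it matches the SELECT order used above):
    #   row[0] confirmed_index, row[1] spent_index, row[2] coinbase,
    #   row[3] puzzle_hash, row[4] coin_parent, row[5] amount (8-byte blob),
    #   row[6] timestamp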

    async def get_coin_states_by_puzzle_hashes(
            self,
            include_spent_coins: bool,
            puzzle_hashes: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinState]:
        if len(puzzle_hashes) == 0:
            return []

        coins = set()
        puzzle_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            puzzle_hashes_db = tuple(puzzle_hashes)
        else:
            puzzle_hashes_db = tuple([ph.hex() for ph in puzzle_hashes])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
                f'WHERE puzzle_hash in ({"?," * (len(puzzle_hashes) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                puzzle_hashes_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coins.add(self.row_to_coin_state(row))

            return list(coins)

    async def get_coin_records_by_parent_ids(
            self,
            include_spent_coins: bool,
            parent_ids: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinRecord]:
        if len(parent_ids) == 0:
            return []

        coins = set()
        parent_ids_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            parent_ids_db = tuple(parent_ids)
        else:
            parent_ids_db = tuple([pid.hex() for pid in parent_ids])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_parent in ({"?," * (len(parent_ids) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                parent_ids_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
            return list(coins)

    async def get_coin_state_by_ids(
            self,
            include_spent_coins: bool,
            coin_ids: List[bytes32],
            start_height: uint32 = uint32(0),
            end_height: uint32 = uint32((2**32) - 1),
    ) -> List[CoinState]:
        if len(coin_ids) == 0:
            return []

        coins = set()
        coin_ids_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            coin_ids_db = tuple(coin_ids)
        else:
            coin_ids_db = tuple([pid.hex() for pid in coin_ids])
        async with self.coin_record_db.execute(
                f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                f'coin_parent, amount, timestamp FROM coin_record WHERE coin_name in ({"?," * (len(coin_ids) - 1)}?) '
                f"AND confirmed_index>=? AND confirmed_index<? "
                f"{'' if include_spent_coins else 'AND spent_index=0'}",
                coin_ids_db + (start_height, end_height),
        ) as cursor:

            for row in await cursor.fetchall():
                coins.add(self.row_to_coin_state(row))
            return list(coins)

    async def rollback_to_block(self, block_index: int) -> List[CoinRecord]:
        """
        Note that block_index can be negative, in which case everything is rolled back
        Returns the list of coin records that have been modified
        """
        # Update memory cache
        delete_queue: List[bytes32] = []
        for coin_name, coin_record in list(
                self.coin_record_cache.cache.items()):
            if int(coin_record.spent_block_index) > block_index:
                new_record = CoinRecord(
                    coin_record.coin,
                    coin_record.confirmed_block_index,
                    uint32(0),
                    coin_record.coinbase,
                    coin_record.timestamp,
                )
                self.coin_record_cache.put(coin_record.coin.name(), new_record)
            if int(coin_record.confirmed_block_index) > block_index:
                delete_queue.append(coin_name)

        for coin_name in delete_queue:
            self.coin_record_cache.remove(coin_name)

        coin_changes: Dict[bytes32, CoinRecord] = {}
        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE confirmed_index>?",
            (block_index, ),
        ) as cursor:
            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, uint32(0), row[1], row[2], uint64(0))
                coin_changes[record.name] = record

        # Delete from storage
        await self.coin_record_db.execute(
            "DELETE FROM coin_record WHERE confirmed_index>?", (block_index, ))

        async with self.coin_record_db.execute(
                "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
                "coin_parent, amount, timestamp FROM coin_record WHERE confirmed_index>?",
            (block_index, ),
        ) as cursor:
            for row in await cursor.fetchall():
                coin = self.row_to_coin(row)
                record = CoinRecord(coin, row[0], uint32(0), row[2], row[6])
                if record.name not in coin_changes:
                    coin_changes[record.name] = record

        if self.db_wrapper.db_version == 2:
            await self.coin_record_db.execute(
                "UPDATE coin_record SET spent_index=0 WHERE spent_index>?",
                (block_index, ))
        else:
            await self.coin_record_db.execute(
                "UPDATE coin_record SET spent_index = 0, spent = 0 WHERE spent_index>?",
                (block_index, ))
        return list(coin_changes.values())

    # Store CoinRecord in DB and ram cache
    async def _add_coin_records(self, records: List[CoinRecord]) -> None:

        if self.db_wrapper.db_version == 2:
            values2 = []
            for record in records:
                self.coin_record_cache.put(record.coin.name(), record)
                values2.append((
                    record.coin.name(),
                    record.confirmed_block_index,
                    record.spent_block_index,
                    int(record.coinbase),
                    record.coin.puzzle_hash,
                    record.coin.parent_coin_info,
                    bytes(record.coin.amount),
                    record.timestamp,
                ))
            await self.coin_record_db.executemany(
                "INSERT INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
                values2,
            )
        else:
            values = []
            for record in records:
                self.coin_record_cache.put(record.coin.name(), record)
                values.append((
                    record.coin.name().hex(),
                    record.confirmed_block_index,
                    record.spent_block_index,
                    int(record.spent),
                    int(record.coinbase),
                    record.coin.puzzle_hash.hex(),
                    record.coin.parent_coin_info.hex(),
                    bytes(record.coin.amount),
                    record.timestamp,
                ))
            await self.coin_record_db.executemany(
                "INSERT INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)",
                values,
            )

    # Update coin_record to be spent in DB
    async def _set_spent(self, coin_names: List[bytes32], index: uint32):

        assert len(coin_names) == 0 or index > 0
        # if this coin is in the cache, mark it as spent in there
        updates = []
        for coin_name in coin_names:
            r = self.coin_record_cache.get(coin_name)
            if r is not None:
                self.coin_record_cache.put(
                    r.name,
                    CoinRecord(r.coin, r.confirmed_block_index, index,
                               r.coinbase, r.timestamp))
            updates.append((index, self.maybe_to_hex(coin_name)))

        if self.db_wrapper.db_version == 2:
            await self.coin_record_db.executemany(
                "UPDATE OR FAIL coin_record SET spent_index=? WHERE coin_name=?",
                updates)
        else:
            await self.coin_record_db.executemany(
                "UPDATE OR FAIL coin_record SET spent=1,spent_index=? WHERE coin_name=?",
                updates)
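
Several queries in these stores build their IN clause dynamically with the
'"?," * (len(xs) - 1)' idiom, producing one placeholder per parameter. A
self-contained sketch of what that expands to (the table and values here are
illustrative):

import sqlite3

names = ["aa", "bb", "cc"]
placeholders = "?," * (len(names) - 1) + "?"  # -> "?,?,?"
sql = f"SELECT coin_name FROM coin_record WHERE coin_name in ({placeholders})"

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE coin_record(coin_name text PRIMARY KEY)")
conn.executemany("INSERT INTO coin_record VALUES(?)", [(n,) for n in names])
print([row[0] for row in conn.execute(sql, names)])  # ['aa', 'bb', 'cc']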
Example no. 18
class BlockStore:
    db: aiosqlite.Connection
    block_cache: LRUCache
    db_wrapper: DBWrapper
    ses_challenge_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute("pragma journal_mode=wal")
        await self.db.execute("pragma synchronous=2")
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
            "  is_block tinyint, is_fully_compactified tinyint, block blob)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint,"
            "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
        )

        # todo remove in v1.2
        await self.db.execute("DROP TABLE IF EXISTS sub_epoch_segments_v2")

        # Sub epoch segments for weight proofs
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3(ses_block_hash text PRIMARY KEY, challenge_segments blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)"
        )
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on full_blocks(is_block)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS is_block on block_records(is_block)")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        self.ses_challenge_cache = LRUCache(50)
        return self

    async def add_full_block(self, header_hash: bytes32, block: FullBlock,
                             block_record: BlockRecord) -> None:
        self.block_cache.put(header_hash, block)
        cursor_1 = await self.db.execute(
            "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
            (
                header_hash.hex(),
                block.height,
                int(block.is_transaction_block()),
                int(block.is_fully_compactified()),
                bytes(block),
            ),
        )

        await cursor_1.close()

        cursor_2 = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
            (
                header_hash.hex(),
                block.prev_header_hash.hex(),
                block.height,
                bytes(block_record),
                None if block_record.sub_epoch_summary_included is None else
                bytes(block_record.sub_epoch_summary_included),
                False,
                block.is_transaction_block(),
            ),
        )
        await cursor_2.close()

    async def persist_sub_epoch_challenge_segments(
            self, ses_block_hash: bytes32,
            segments: List[SubEpochChallengeSegment]) -> None:
        async with self.db_wrapper.lock:
            cursor_1 = await self.db.execute(
                "INSERT OR REPLACE INTO sub_epoch_segments_v3 VALUES(?, ?)",
                (ses_block_hash.hex(), bytes(SubEpochSegments(segments))),
            )
            await cursor_1.close()
            await self.db.commit()

    async def get_sub_epoch_challenge_segments(
        self,
        ses_block_hash: bytes32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        cached = self.ses_challenge_cache.get(ses_block_hash)
        if cached is not None:
            return cached
        cursor = await self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v3 WHERE ses_block_hash=?",
            (ses_block_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            challenge_segments = SubEpochSegments.from_bytes(
                row[0]).challenge_segments
            self.ses_challenge_cache.put(ses_block_hash, challenge_segments)
            return challenge_segments
        return None

    def rollback_cache_block(self, header_hash: bytes32):
        try:
            self.block_cache.remove(header_hash)
        except KeyError:
            # this is best effort. When rolling back, we may not have added the
            # block to the cache yet
            pass

    async def get_full_block(self,
                             header_hash: bytes32) -> Optional[FullBlock]:
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return cached
        log.debug(f"cache miss for block {header_hash.hex()}")
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            block = FullBlock.from_bytes(row[0])
            self.block_cache.put(header_hash, block)
            return block
        return None

    async def get_full_block_bytes(self,
                                   header_hash: bytes32) -> Optional[bytes]:
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return bytes(cached)
        log.debug(f"cache miss for block {header_hash.hex()}")
        cursor = await self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return row[0]
        return None

    async def get_full_blocks_at(self,
                                 heights: List[uint32]) -> List[FullBlock]:
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from full_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [FullBlock.from_bytes(row[0]) for row in rows]
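
The f-string in get_full_blocks_at only builds the list of "?" placeholders; the height values themselves are still bound by sqlite, so no data is interpolated into the SQL. A quick sketch of what the expression expands to:

# "?," * (n - 1) + "?" expands to a comma-separated placeholder list:
heights_db = (10, 11, 12)
placeholders = "?," * (len(heights_db) - 1) + "?"  # -> "?,?,?"
query = f"SELECT block from full_blocks WHERE height in ({placeholders})"
# -> SELECT block from full_blocks WHERE height in (?,?,?)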

    async def get_block_records_by_hash(self, header_hashes: List[bytes32]):
        """
        Returns a list of Block Records, ordered by the same order in which header_hashes are passed in.
        Throws an exception if the blocks are not present
        """
        if len(header_hashes) == 0:
            return []

        header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = f'SELECT block from block_records WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, header_hashes_db)
        rows = await cursor.fetchall()
        await cursor.close()
        all_blocks: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            block_rec: BlockRecord = BlockRecord.from_bytes(row[0])
            all_blocks[block_rec.header_hash] = block_rec
        ret: List[BlockRecord] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret
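
Rows come back from the SELECT in arbitrary order, so the method keys them by hash and then re-emits them in the caller's order, failing fast on any missing hash. A toy illustration of that reassembly step, with made-up values:

requested = ["b", "a", "c"]
fetched = {"a": 1, "c": 3, "b": 2}         # arbitrary order from the DB
ordered = [fetched[h] for h in requested]  # -> [2, 1, 3]
# a hash absent from `fetched` would raise, mirroring the ValueError above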

    async def get_blocks_by_hash(
            self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Returns a list of full blocks, ordered by the same order in which header_hashes are passed in.
        Throws an exception if the blocks are not present
        """

        if len(header_hashes) == 0:
            return []

        header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = (
            f'SELECT header_hash, block from full_blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        )
        cursor = await self.db.execute(formatted_str, header_hashes_db)
        rows = await cursor.fetchall()
        await cursor.close()
        all_blocks: Dict[bytes32, FullBlock] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            full_block: FullBlock = FullBlock.from_bytes(row[1])
            all_blocks[header_hash] = full_block
            self.block_cache.put(header_hash, full_block)
        ret: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks in range between start and stop
        if present.
        """

        formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"

        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
        peak header hash.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        peak_row = await res.fetchone()
        await res.close()
        if peak_row is None:
            return {}, None

        formatted_str = f"SELECT header_hash, block  from block_records WHERE height >= {peak_row[2] - blocks_n}"
        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash = bytes.fromhex(row[0])
            ret[header_hash] = BlockRecord.from_bytes(row[1])
        return ret, bytes.fromhex(peak_row[0])

    async def get_peak_height_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """

        res = await self.db.execute(
            "SELECT * from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for row in rows:
            hash_to_prev_hash[bytes.fromhex(row[0])] = bytes.fromhex(row[1])
            hash_to_height[bytes.fromhex(row[0])] = row[2]
            if row[3] is not None:
                hash_to_summary[bytes.fromhex(
                    row[0])] = SubEpochSummary.from_bytes(row[3])

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries
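
block_records may contain orphans, which is why get_peak_height_dicts walks prev_hash pointers back from the peak rather than trusting the height index alone. A self-contained toy run, with made-up hashes, showing that an orphan never enters the result:

# Toy data: height 1 has an orphan "x" that the walk never visits.
hash_to_prev_hash = {"peak": "b", "b": "genesis", "x": "genesis"}
hash_to_height = {"peak": 2, "b": 1, "x": 1, "genesis": 0}

height_to_hash = {}
curr = "peak"
while True:
    height_to_hash[hash_to_height[curr]] = curr
    if hash_to_height[curr] == 0:
        break
    curr = hash_to_prev_hash[curr]
assert height_to_hash == {2: "peak", 1: "b", 0: "genesis"}  # "x" is excluded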

    async def set_peak(self, header_hash: bytes32) -> None:
        # We need to be in a sqlite transaction here.
        # Note: we do not commit this to the database yet, as we need to also change the coin store
        cursor_1 = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await cursor_1.close()
        cursor_2 = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await cursor_2.close()

    async def is_fully_compactified(self,
                                    header_hash: bytes32) -> Optional[bool]:
        cursor = await self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return bool(row[0])

    async def get_random_not_compactified(self, number: int) -> List[int]:
        # Since orphan blocks do not get compactified, we need to check whether all blocks with a
        # certain height are not compact. And if we do have compact orphan blocks, then all that
        # happens is that the occasional chain block stays uncompact - not ideal, but harmless.
        cursor = await self.db.execute(
            f"SELECT height FROM full_blocks GROUP BY height HAVING sum(is_fully_compactified)=0 "
            f"ORDER BY RANDOM() LIMIT {number}")
        rows = await cursor.fetchall()
        await cursor.close()

        heights = []
        for row in rows:
            heights.append(int(row[0]))

        return heights
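
The GROUP BY/HAVING query treats a height as compactified as soon as any block at that height is compact, which is exactly the orphan caveat the comment describes. A self-contained sqlite3 demo of the same query shape:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE full_blocks(height int, is_fully_compactified tinyint)")
# height 1: a compact orphan plus an uncompact chain block -> filtered out
# height 2: every block uncompact -> selected
conn.executemany(
    "INSERT INTO full_blocks VALUES(?, ?)", [(1, 1), (1, 0), (2, 0), (2, 0)]
)
rows = conn.execute(
    "SELECT height FROM full_blocks GROUP BY height "
    "HAVING sum(is_fully_compactified)=0"
).fetchall()
assert rows == [(2,)]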
Example no. 19
    def __init__(self, private_key: PrivateKey, config: Dict,
                 constants: ConsensusConstants):
        self.follow_singleton_tasks: Dict[bytes32, asyncio.Task] = {}
        self.log = logging
        # If you want to log to a file: use filename='example.log', encoding='utf-8'
        self.log.basicConfig(level=logging.INFO)

        # We load our configurations from here
        with open(os.getcwd() + "/config.yaml") as f:
            pool_config: Dict = yaml.safe_load(f)

        # Set our pool info here
        self.info_default_res = pool_config["pool_info"]["default_res"]
        self.info_name = pool_config["pool_info"]["name"]
        self.info_logo_url = pool_config["pool_info"]["logo_url"]
        self.info_description = pool_config["pool_info"]["description"]
        self.welcome_message = pool_config["welcome_message"]

        self.private_key = private_key
        self.public_key: G1Element = private_key.get_g1()
        self.config = config
        self.constants = constants

        self.store: Optional[PoolStore] = None

        self.pool_fee = pool_config["pool_fee"]

        # This number should be held constant and be consistent for every pool in the network. DO NOT CHANGE
        self.iters_limit = self.constants.POOL_SUB_SLOT_ITERS // 64

        # This number should not be changed, since users will put this into their singletons
        self.relative_lock_height = uint32(100)

        # TODO(pool): potentially tweak these numbers for security and performance
        # This is what the user enters into the input field. This exact value will be stored on the blockchain
        self.pool_url = pool_config["pool_url"]
        self.min_difficulty = uint64(
            pool_config["min_difficulty"]
        )  # 10 difficulty is about 1 proof a day per plot
        self.default_difficulty: uint64 = uint64(
            pool_config["default_difficulty"])

        self.pending_point_partials: Optional[asyncio.Queue] = None
        self.recent_points_added: LRUCache = LRUCache(20000)

        # The time in minutes for an authentication token to be valid. See "Farmer authentication" in SPECIFICATION.md
        self.authentication_token_timeout: uint8 = pool_config[
            "authentication_token_timeout"]

        # This is where the block rewards will get paid out to. The pool needs to support this address forever,
        # since the farmers will encode it into their singleton on the blockchain. WARNING: the default pool code
        # completely spends this wallet and distributes it to users, so don't put any additional funds in here
        # that you do not want to distribute. Even if the funds are in a different address than this one, they WILL
        # be spent by this code! So only put funds that you want to distribute to pool members here.

        # Using 2164248527
        self.default_target_puzzle_hash: bytes32 = bytes32(
            decode_puzzle_hash(pool_config["default_target_address"]))

        # The pool fees will be sent to this address. This MUST be on a different key than the target_puzzle_hash,
        # otherwise, the fees will be sent to the users. Using 690783650
        self.pool_fee_puzzle_hash: bytes32 = bytes32(
            decode_puzzle_hash(pool_config["pool_fee_address"]))

        # This is the wallet fingerprint and ID for the wallet spending the funds from `self.default_target_puzzle_hash`
        self.wallet_fingerprint = pool_config["wallet_fingerprint"]
        self.wallet_id = pool_config["wallet_id"]

        # We need to check for slow farmers. If farmers cannot submit proofs in time, they won't be able to win
        # any rewards either. This number can be tweaked to be more or less strict. More strict ensures everyone
        # gets high rewards, but it might cause some of the slower farmers to not be able to participate in the pool.
        self.partial_time_limit: int = pool_config["partial_time_limit"]

        # There is always a risk of a reorg, in which case we cannot reward farmers that submitted partials in that
        # reorg. That is why we have a time delay before changing any account points.
        self.partial_confirmation_delay: int = pool_config[
            "partial_confirmation_delay"]

        # These are the puzzle hashes ("phs") that we want to look for on chain, so that we can claim them to our pool
        self.scan_p2_singleton_puzzle_hashes: Set[bytes32] = set()

        # Don't scan anything before this height, for efficiency (for example pool start date)
        self.scan_start_height: uint32 = uint32(
            pool_config["scan_start_height"])

        # Interval for scanning and collecting the pool rewards
        self.collect_pool_rewards_interval = pool_config[
            "collect_pool_rewards_interval"]

        # After this many confirmations, a transaction is considered final and irreversible
        self.confirmation_security_threshold = pool_config[
            "confirmation_security_threshold"]

        # Interval for making payout transactions to farmers
        self.payment_interval = pool_config["payment_interval"]

        # We will not make transactions with more targets than this, to ensure our transaction gets into the blockchain
        # faster.
        self.max_additions_per_transaction = pool_config[
            "max_additions_per_transaction"]

        # This is the list of payments that we have not sent yet, to farmers
        self.pending_payments: Optional[asyncio.Queue] = None

        # Keeps track of the latest state of our node
        self.blockchain_state = {"peak": None}

        # Whether or not the wallet is synced (required to make payments)
        self.wallet_synced = False

        # We target this many partials per time_target seconds, and adjust difficulty after receiving this many partials.
        self.number_of_partials_target: int = pool_config[
            "number_of_partials_target"]
        self.time_target: int = pool_config["time_target"]

        # Tasks (infinite While loops) for different purposes
        self.confirm_partials_loop_task: Optional[asyncio.Task] = None
        self.collect_pool_rewards_loop_task: Optional[asyncio.Task] = None
        self.create_payment_loop_task: Optional[asyncio.Task] = None
        self.submit_payment_loop_task: Optional[asyncio.Task] = None
        self.get_peak_loop_task: Optional[asyncio.Task] = None

        self.node_rpc_client: Optional[FullNodeRpcClient] = None
        self.wallet_rpc_client: Optional[WalletRpcClient] = None
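
Every key the constructor reads comes from config.yaml. A minimal sketch of the expected shape, with placeholder values only (the keys mirror the pool_config lookups above; none of the numbers are recommendations):

import yaml

sample = yaml.safe_load("""
pool_info:
  default_res: ""
  name: "My Pool"
  logo_url: ""
  description: ""
welcome_message: "Welcome!"
pool_fee: 0.01
pool_url: "https://pool.example.com"
min_difficulty: 10
default_difficulty: 10
authentication_token_timeout: 5
default_target_address: "xch1..."
pool_fee_address: "xch1..."
wallet_fingerprint: 0
wallet_id: 1
partial_time_limit: 25
partial_confirmation_delay: 30
scan_start_height: 0
collect_pool_rewards_interval: 600
confirmation_security_threshold: 6
payment_interval: 600
max_additions_per_transaction: 400
number_of_partials_target: 300
time_target: 86400
""")
assert "pool_info" in sample and "pool_fee" in sample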
Example no. 20
    def test_lru_cache(self):
        cache = LRUCache(5)

        assert cache.get(b"0") is None

        assert len(cache.cache) == 0
        cache.put(b"0", 1)
        assert len(cache.cache) == 1
        assert cache.get(b"0") == 1
        cache.put(b"0", 2)
        cache.put(b"0", 3)
        cache.put(b"0", 4)
        cache.put(b"0", 6)
        assert cache.get(b"0") == 6
        assert len(cache.cache) == 1

        cache.put(b"1", 1)
        assert len(cache.cache) == 2
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        cache.put(b"2", 2)
        assert len(cache.cache) == 3
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        cache.put(b"3", 3)
        assert len(cache.cache) == 4
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        assert cache.get(b"3") == 3
        cache.put(b"4", 4)
        assert len(cache.cache) == 5
        assert cache.get(b"0") == 6
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        assert cache.get(b"4") == 4
        cache.put(b"5", 5)
        assert cache.get(b"5") == 5
        assert len(cache.cache) == 5
        print(cache.cache)
        assert cache.get(b"3") is None  # 3 is least recently used
        assert cache.get(b"1") == 1
        assert cache.get(b"2") == 2
        cache.put(b"7", 7)
        assert len(cache.cache) == 5
        assert cache.get(b"0") is None
        assert cache.get(b"1") == 1
Example no. 21
class BlockStore:
    db: aiosqlite.Connection
    block_cache: LRUCache
    db_wrapper: DBWrapper
    ses_challenge_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        self = cls()

        # All full blocks which have been added to the blockchain. Header_hash -> block
        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db

        if self.db_wrapper.db_version == 2:

            # TODO: most data in block is duplicated in block_record. The only
            # reason for this is that our parsing of a FullBlock is so slow,
            # it's faster to store duplicate data to parse less when we just
            # need the BlockRecord. Once we fix the parsing (and data structure)
            # of FullBlock, this can use less space
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS full_blocks("
                "header_hash blob PRIMARY KEY,"
                "prev_hash blob,"
                "height bigint,"
                "sub_epoch_summary blob,"
                "is_fully_compactified tinyint,"
                "in_main_chain tinyint,"
                "block blob,"
                "block_record blob)"
            )

            # This is a single-row table containing the hash of the current
            # peak. The "key" field is there to make update statements simple
            await self.db.execute("CREATE TABLE IF NOT EXISTS current_peak(key int PRIMARY KEY, hash blob)")

            await self.db.execute("CREATE INDEX IF NOT EXISTS height on full_blocks(height)")

            # Sub epoch segments for weight proofs
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3("
                "ses_block_hash blob PRIMARY KEY,"
                "challenge_segments blob)"
            )

            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS is_fully_compactified ON"
                " full_blocks(is_fully_compactified, in_main_chain) WHERE in_main_chain=1"
            )
            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS main_chain ON full_blocks(height, in_main_chain) WHERE in_main_chain=1"
            )

        else:

            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
                "  is_block tinyint, is_fully_compactified tinyint, block blob)"
            )

            # Block records
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS block_records(header_hash "
                "text PRIMARY KEY, prev_hash text, height bigint,"
                "block blob, sub_epoch_summary blob, is_peak tinyint, is_block tinyint)"
            )

            # Sub epoch segments for weight proofs
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS sub_epoch_segments_v3(ses_block_hash text PRIMARY KEY,"
                "challenge_segments blob)"
            )

            # Height index so we can look up in order of height for sync purposes
            await self.db.execute("CREATE INDEX IF NOT EXISTS full_block_height on full_blocks(height)")
            await self.db.execute(
                "CREATE INDEX IF NOT EXISTS is_fully_compactified on full_blocks(is_fully_compactified)"
            )

            await self.db.execute("CREATE INDEX IF NOT EXISTS height on block_records(height)")

            await self.db.execute("CREATE INDEX IF NOT EXISTS peak on block_records(is_peak) where is_peak = 1")

        await self.db.commit()
        self.block_cache = LRUCache(1000)
        self.ses_challenge_cache = LRUCache(50)
        return self

    def maybe_from_hex(self, field: Any) -> bytes:
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return bytes.fromhex(field)

    def maybe_to_hex(self, field: bytes) -> Any:
        if self.db_wrapper.db_version == 2:
            return field
        else:
            return field.hex()

    def compress(self, block: FullBlock) -> bytes:
        return zstd.compress(bytes(block))

    def maybe_decompress(self, block_bytes: bytes) -> FullBlock:
        if self.db_wrapper.db_version == 2:
            return FullBlock.from_bytes(zstd.decompress(block_bytes))
        else:
            return FullBlock.from_bytes(block_bytes)
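
compress and maybe_decompress rely on the zstd module's one-shot API; a round-trip sanity check, assuming the same zstd package these methods use:

import zstd

payload = b"\x00" * 1024                 # highly compressible block bytes
blob = zstd.compress(payload)
assert len(blob) < len(payload)          # v2 stores the compressed blob
assert zstd.decompress(blob) == payload  # the read path restores the original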

    async def rollback(self, height: int) -> None:
        if self.db_wrapper.db_version == 2:
            await self.db.execute(
                "UPDATE OR FAIL full_blocks SET in_main_chain=0 WHERE height>? AND in_main_chain=1", (height,)
            )

    async def set_in_chain(self, header_hashes: List[Tuple[bytes32]]) -> None:
        if self.db_wrapper.db_version == 2:
            await self.db.executemany(
                "UPDATE OR FAIL full_blocks SET in_main_chain=1 WHERE header_hash=?", header_hashes
            )
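
Note that executemany binds one parameter tuple per row, which is why set_in_chain takes List[Tuple[bytes32]] rather than a flat list of hashes. A hypothetical call site (store and hashes are assumed names):

# Each header hash becomes its own 1-tuple before being handed to executemany.
async def mark_main_chain(store: "BlockStore", hashes: "list[bytes32]") -> None:
    await store.set_in_chain([(hh,) for hh in hashes])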

    async def add_full_block(
        self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord, in_main_chain: bool
    ) -> None:
        self.block_cache.put(header_hash, block)

        if self.db_wrapper.db_version == 2:

            ses: Optional[bytes] = (
                None
                if block_record.sub_epoch_summary_included is None
                else bytes(block_record.sub_epoch_summary_included)
            )

            await self.db.execute(
                "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    header_hash,
                    block.prev_header_hash,
                    block.height,
                    ses,
                    int(block.is_fully_compactified()),
                    in_main_chain,  # in_main_chain
                    self.compress(block),
                    bytes(block_record),
                ),
            )

        else:
            await self.db.execute(
                "INSERT OR REPLACE INTO full_blocks VALUES(?, ?, ?, ?, ?)",
                (
                    header_hash.hex(),
                    block.height,
                    int(block.is_transaction_block()),
                    int(block.is_fully_compactified()),
                    bytes(block),
                ),
            )

            await self.db.execute(
                "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?,?, ?, ?)",
                (
                    header_hash.hex(),
                    block.prev_header_hash.hex(),
                    block.height,
                    bytes(block_record),
                    None
                    if block_record.sub_epoch_summary_included is None
                    else bytes(block_record.sub_epoch_summary_included),
                    False,
                    block.is_transaction_block(),
                ),
            )

    async def persist_sub_epoch_challenge_segments(
        self, ses_block_hash: bytes32, segments: List[SubEpochChallengeSegment]
    ) -> None:
        async with self.db_wrapper.lock:
            await self.db.execute(
                "INSERT OR REPLACE INTO sub_epoch_segments_v3 VALUES(?, ?)",
                (self.maybe_to_hex(ses_block_hash), bytes(SubEpochSegments(segments))),
            )
            await self.db.commit()

    async def get_sub_epoch_challenge_segments(
        self,
        ses_block_hash: bytes32,
    ) -> Optional[List[SubEpochChallengeSegment]]:
        cached = self.ses_challenge_cache.get(ses_block_hash)
        if cached is not None:
            return cached

        async with self.db.execute(
            "SELECT challenge_segments from sub_epoch_segments_v3 WHERE ses_block_hash=?",
            (self.maybe_to_hex(ses_block_hash),),
        ) as cursor:
            row = await cursor.fetchone()

        if row is not None:
            challenge_segments = SubEpochSegments.from_bytes(row[0]).challenge_segments
            self.ses_challenge_cache.put(ses_block_hash, challenge_segments)
            return challenge_segments
        return None

    def rollback_cache_block(self, header_hash: bytes32):
        try:
            self.block_cache.remove(header_hash)
        except KeyError:
            # this is best effort. When rolling back, we may not have added the
            # block to the cache yet
            pass

    async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]:
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return cached
        log.debug(f"cache miss for block {header_hash.hex()}")
        async with self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is not None:
            block = self.maybe_decompress(row[0])
            self.block_cache.put(header_hash, block)
            return block
        return None

    async def get_full_block_bytes(self, header_hash: bytes32) -> Optional[bytes]:
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            log.debug(f"cache hit for block {header_hash.hex()}")
            return bytes(cached)
        log.debug(f"cache miss for block {header_hash.hex()}")
        async with self.db.execute(
            "SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is not None:
            if self.db_wrapper.db_version == 2:
                return zstd.decompress(row[0])
            else:
                return row[0]

        return None

    async def get_full_blocks_at(self, heights: List[uint32]) -> List[FullBlock]:
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from full_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        async with self.db.execute(formatted_str, heights_db) as cursor:
            ret: List[FullBlock] = []
            for row in await cursor.fetchall():
                ret.append(self.maybe_decompress(row[0]))
            return ret

    async def get_block_records_by_hash(self, header_hashes: List[bytes32]):
        """
        Returns a list of Block Records, ordered by the same order in which header_hashes are passed in.
        Throws an exception if the blocks are not present
        """
        if len(header_hashes) == 0:
            return []

        all_blocks: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:
            async with self.db.execute(
                "SELECT header_hash,block_record FROM full_blocks "
                f'WHERE header_hash in ({"?," * (len(header_hashes) - 1)}?)',
                tuple(header_hashes),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    all_blocks[header_hash] = BlockRecord.from_bytes(row[1])
        else:
            formatted_str = f'SELECT block from block_records WHERE header_hash in ({"?," * (len(header_hashes) - 1)}?)'
            async with self.db.execute(formatted_str, tuple([hh.hex() for hh in header_hashes])) as cursor:
                for row in await cursor.fetchall():
                    block_rec: BlockRecord = BlockRecord.from_bytes(row[0])
                    all_blocks[block_rec.header_hash] = block_rec

        ret: List[BlockRecord] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_blocks_by_hash(self, header_hashes: List[bytes32]) -> List[FullBlock]:
        """
        Returns a list of full blocks, ordered by the same order in which header_hashes are passed in.
        Throws an exception if the blocks are not present
        """

        if len(header_hashes) == 0:
            return []

        header_hashes_db: Tuple[Any, ...]
        if self.db_wrapper.db_version == 2:
            header_hashes_db = tuple(header_hashes)
        else:
            header_hashes_db = tuple([hh.hex() for hh in header_hashes])
        formatted_str = (
            f'SELECT header_hash, block from full_blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)'
        )
        all_blocks: Dict[bytes32, FullBlock] = {}
        async with self.db.execute(formatted_str, header_hashes_db) as cursor:
            for row in await cursor.fetchall():
                header_hash = self.maybe_from_hex(row[0])
                full_block: FullBlock = self.maybe_decompress(row[1])
                # TODO: address hint error and remove ignore
                #       error: Invalid index type "bytes" for "Dict[bytes32, FullBlock]";
                # expected type "bytes32"  [index]
                all_blocks[header_hash] = full_block  # type: ignore[index]
                self.block_cache.put(header_hash, full_block)
        ret: List[FullBlock] = []
        for hh in header_hashes:
            if hh not in all_blocks:
                raise ValueError(f"Header hash {hh} not in the blockchain")
            ret.append(all_blocks[hh])
        return ret

    async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:

        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT block_record FROM full_blocks WHERE header_hash=?",
                (header_hash,),
            ) as cursor:
                row = await cursor.fetchone()
            if row is not None:
                return BlockRecord.from_bytes(row[0])

        else:
            async with self.db.execute(
                "SELECT block from block_records WHERE header_hash=?",
                (header_hash.hex(),),
            ) as cursor:
                row = await cursor.fetchone()
            if row is not None:
                return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks in range between start and stop
        if present.
        """

        ret: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND height <= ?",
                (start, stop),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        else:

            formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"

            async with self.db.execute(formatted_str) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(self.maybe_from_hex(row[0]))
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret

    async def get_peak(self) -> Optional[Tuple[bytes32, uint32]]:

        if self.db_wrapper.db_version == 2:
            async with self.db.execute("SELECT hash FROM current_peak WHERE key = 0") as cursor:
                peak_row = await cursor.fetchone()
            if peak_row is None:
                return None
            async with self.db.execute("SELECT height FROM full_blocks WHERE header_hash=?", (peak_row[0],)) as cursor:
                peak_height = await cursor.fetchone()
            if peak_height is None:
                return None
            return bytes32(peak_row[0]), uint32(peak_height[0])
        else:
            async with self.db.execute("SELECT header_hash, height from block_records WHERE is_peak = 1") as cursor:
                peak_row = await cursor.fetchone()
            if peak_row is None:
                return None
            return bytes32(bytes.fromhex(peak_row[0])), uint32(peak_row[1])

    async def get_block_records_close_to_peak(
        self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
        peak header hash.
        """

        peak = await self.get_peak()
        if peak is None:
            return {}, None

        ret: Dict[bytes32, BlockRecord] = {}
        if self.db_wrapper.db_version == 2:

            async with self.db.execute(
                "SELECT header_hash, block_record FROM full_blocks WHERE height >= ?",
                (peak[1] - blocks_n,),
            ) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(row[0])
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        else:
            formatted_str = f"SELECT header_hash, block  from block_records WHERE height >= {peak[1] - blocks_n}"
            async with self.db.execute(formatted_str) as cursor:
                for row in await cursor.fetchall():
                    header_hash = bytes32(self.maybe_from_hex(row[0]))
                    ret[header_hash] = BlockRecord.from_bytes(row[1])

        return ret, peak[0]

    async def set_peak(self, header_hash: bytes32) -> None:
        # We need to be in a sqlite transaction here.
        # Note: we do not commit this to the database yet, as we need to also change the coin store

        if self.db_wrapper.db_version == 2:
            # Note: we use the key field as 0 just to ensure all inserts replace the existing row
            await self.db.execute("INSERT OR REPLACE INTO current_peak VALUES(?, ?)", (0, header_hash))
        else:
            await self.db.execute("UPDATE block_records SET is_peak=0 WHERE is_peak=1")
            await self.db.execute(
                "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
                (self.maybe_to_hex(header_hash),),
            )

    async def is_fully_compactified(self, header_hash: bytes32) -> Optional[bool]:
        async with self.db.execute(
            "SELECT is_fully_compactified from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        return bool(row[0])

    async def get_random_not_compactified(self, number: int) -> List[int]:

        if self.db_wrapper.db_version == 2:
            async with self.db.execute(
                f"SELECT height FROM full_blocks WHERE in_main_chain=1 AND is_fully_compactified=0 "
                f"ORDER BY RANDOM() LIMIT {number}"
            ) as cursor:
                rows = await cursor.fetchall()
        else:
            # Since orphan blocks do not get compactified, we need to check whether all blocks with a
            # certain height are not compact. And if we do have compact orphan blocks, then all that
            # happens is that the occasional chain block stays uncompact - not ideal, but harmless.
            async with self.db.execute(
                f"SELECT height FROM full_blocks GROUP BY height HAVING sum(is_fully_compactified)=0 "
                f"ORDER BY RANDOM() LIMIT {number}"
            ) as cursor:
                rows = await cursor.fetchall()

        heights = [int(row[0]) for row in rows]

        return heights

    async def count_compactified_blocks(self) -> int:
        async with self.db.execute("select count(*) from full_blocks where is_fully_compactified=1") as cursor:
            row = await cursor.fetchone()

        assert row is not None

        [count] = row
        return int(count)

    async def count_uncompactified_blocks(self) -> int:
        async with self.db.execute("select count(*) from full_blocks where is_fully_compactified=0") as cursor:
            row = await cursor.fetchone()

        assert row is not None

        [count] = row
        return int(count)
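
A sketch of wiring the store up end to end, assuming DBWrapper wraps an aiosqlite connection and exposes the .db, .db_version and .lock attributes the methods above rely on (its real constructor may differ):

import asyncio
import aiosqlite

async def main() -> None:
    conn = await aiosqlite.connect(":memory:")
    # Assumption: DBWrapper(conn, db_version=2) is illustrative only.
    wrapper = DBWrapper(conn, db_version=2)
    store = await BlockStore.create(wrapper)
    assert await store.get_peak() is None  # a fresh database has no peak row
    await conn.close()

asyncio.run(main())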
Example no. 23
class WalletBlockStore:
    """
    This object handles HeaderBlocks and Blocks stored in the DB used by the wallet.
    """

    db: aiosqlite.Connection
    db_wrapper: DBWrapper
    block_cache: LRUCache

    @classmethod
    async def create(cls, db_wrapper: DBWrapper):
        self = cls()

        self.db_wrapper = db_wrapper
        self.db = db_wrapper.db
        await self.db.execute("pragma journal_mode=wal")
        await self.db.execute("pragma synchronous=2")

        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS header_blocks(header_hash text PRIMARY KEY, height int,"
            " timestamp int, block blob)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS header_hash on header_blocks(header_hash)"
        )

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS timestamp on header_blocks(timestamp)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on header_blocks(height)")

        # Block records
        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS block_records(header_hash "
            "text PRIMARY KEY, prev_hash text, height bigint, weight bigint, total_iters text,"
            "block blob, sub_epoch_summary blob, is_peak tinyint)")

        await self.db.execute(
            "CREATE TABLE IF NOT EXISTS additional_coin_spends(header_hash text PRIMARY KEY, spends_list_blob blob)"
        )

        # Height index so we can look up in order of height for sync purposes
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS height on block_records(height)")

        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS hh on block_records(header_hash)")
        await self.db.execute(
            "CREATE INDEX IF NOT EXISTS peak on block_records(is_peak)")
        await self.db.commit()
        self.block_cache = LRUCache(1000)
        return self

    async def _clear_database(self):
        cursor_2 = await self.db.execute("DELETE FROM header_blocks")
        await cursor_2.close()
        await self.db.commit()

    async def add_block_record(
        self,
        header_block_record: HeaderBlockRecord,
        block_record: BlockRecord,
        additional_coin_spends: List[CoinSolution],
    ):
        """
        Adds a block record to the database. This block record is assumed to be connected
        to the chain, but it may or may not be in the LCA path.
        """
        cached = self.block_cache.get(header_block_record.header_hash)
        if cached is not None:
            # Since the write to the DB can fail, we remove the entry from the cache here
            # to avoid inconsistency; blocks are only added to the cache when reading
            self.block_cache.put(header_block_record.header_hash, None)

        if header_block_record.header.foliage_transaction_block is not None:
            timestamp = header_block_record.header.foliage_transaction_block.timestamp
        else:
            timestamp = uint64(0)
        cursor = await self.db.execute(
            "INSERT OR REPLACE INTO header_blocks VALUES(?, ?, ?, ?)",
            (
                header_block_record.header_hash.hex(),
                header_block_record.height,
                timestamp,
                bytes(header_block_record),
            ),
        )

        await cursor.close()
        cursor_2 = await self.db.execute(
            "INSERT OR REPLACE INTO block_records VALUES(?, ?, ?, ?, ?, ?, ?,?)",
            (
                header_block_record.header.header_hash.hex(),
                header_block_record.header.prev_header_hash.hex(),
                header_block_record.header.height,
                header_block_record.header.weight.to_bytes(
                    128 // 8, "big", signed=False).hex(),
                header_block_record.header.total_iters.to_bytes(
                    128 // 8, "big", signed=False).hex(),
                bytes(block_record),
                None if block_record.sub_epoch_summary_included is None else
                bytes(block_record.sub_epoch_summary_included),
                False,
            ),
        )
        await cursor_2.close()

        if len(additional_coin_spends) > 0:
            blob: bytes = bytes(AdditionalCoinSpends(additional_coin_spends))
            cursor_3 = await self.db.execute(
                "INSERT OR REPLACE INTO additional_coin_spends VALUES(?, ?)",
                (header_block_record.header_hash.hex(), blob))
            await cursor_3.close()
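
The weight and total_iters values above are stored as fixed-width (128-bit) big-endian hex strings, so lexicographic text ordering in sqlite agrees with numeric ordering. A quick check of that property:

# Fixed-width big-endian hex sorts the same way the integers themselves do.
values = [3, 1_000_000, 42]
as_hex = [v.to_bytes(128 // 8, "big", signed=False).hex() for v in values]
assert sorted(as_hex) == [
    v.to_bytes(128 // 8, "big", signed=False).hex() for v in sorted(values)
]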

    async def get_header_block_at(self,
                                  heights: List[uint32]) -> List[HeaderBlock]:
        if len(heights) == 0:
            return []

        heights_db = tuple(heights)
        formatted_str = f'SELECT block from header_blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)'
        cursor = await self.db.execute(formatted_str, heights_db)
        rows = await cursor.fetchall()
        await cursor.close()
        return [HeaderBlock.from_bytes(row[0]) for row in rows]

    async def get_header_block_record(
            self, header_hash: bytes32) -> Optional[HeaderBlockRecord]:
        """Gets a block record from the database, if present"""
        cached = self.block_cache.get(header_hash)
        if cached is not None:
            return cached
        cursor = await self.db.execute(
            "SELECT block from header_blocks WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            hbr: HeaderBlockRecord = HeaderBlockRecord.from_bytes(row[0])
            self.block_cache.put(hbr.header_hash, hbr)
            return hbr
        else:
            return None

    async def get_additional_coin_spends(
            self, header_hash: bytes32) -> Optional[List[CoinSolution]]:
        cursor = await self.db.execute(
            "SELECT spends_list_blob from additional_coin_spends WHERE header_hash=?",
            (header_hash.hex(), ))
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            coin_spends: AdditionalCoinSpends = AdditionalCoinSpends.from_bytes(
                row[0])
            return coin_spends.coin_spends_list
        else:
            return None

    async def get_block_record(self,
                               header_hash: bytes32) -> Optional[BlockRecord]:
        cursor = await self.db.execute(
            "SELECT block from block_records WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None:
            return BlockRecord.from_bytes(row[0])
        return None

    async def get_block_records(
        self, ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """
        cursor = await self.db.execute(
            "SELECT header_hash, block, is_peak from block_records")
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        peak: Optional[bytes32] = None
        for row in rows:
            header_hash_bytes, block_record_bytes, is_peak = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)
            if is_peak:
                assert peak is None  # Sanity check, only one peak
                peak = header_hash
        return ret, peak

    def rollback_cache_block(self, header_hash: bytes32):
        try:
            self.block_cache.remove(header_hash)
        except KeyError:
            # best effort: when rolling back, the block may not have been added
            # to the cache yet
            pass

    async def set_peak(self, header_hash: bytes32) -> None:
        cursor_1 = await self.db.execute(
            "UPDATE block_records SET is_peak=0 WHERE is_peak=1")
        await cursor_1.close()
        cursor_2 = await self.db.execute(
            "UPDATE block_records SET is_peak=1 WHERE header_hash=?",
            (header_hash.hex(), ),
        )
        await cursor_2.close()

    async def get_block_records_close_to_peak(
            self, blocks_n: int
    ) -> Tuple[Dict[bytes32, BlockRecord], Optional[bytes32]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """

        res = await self.db.execute(
            "SELECT header_hash, height from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, None
        header_hash_bytes, peak_height = row
        peak: bytes32 = bytes32(bytes.fromhex(header_hash_bytes))

        formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {peak_height - blocks_n}"
        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)
        return ret, peak

    async def get_header_blocks_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, HeaderBlock]:

        formatted_str = f"SELECT header_hash, block from header_blocks WHERE height >= {start} and height <= {stop}"

        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, HeaderBlock] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = HeaderBlock.from_bytes(block_record_bytes)

        return ret

    async def get_block_records_in_range(
        self,
        start: int,
        stop: int,
    ) -> Dict[bytes32, BlockRecord]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """

        formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"

        cursor = await self.db.execute(formatted_str)
        rows = await cursor.fetchall()
        await cursor.close()
        ret: Dict[bytes32, BlockRecord] = {}
        for row in rows:
            header_hash_bytes, block_record_bytes = row
            header_hash = bytes.fromhex(header_hash_bytes)
            ret[header_hash] = BlockRecord.from_bytes(block_record_bytes)

        return ret

    async def get_peak_heights_dicts(
            self
    ) -> Tuple[Dict[uint32, bytes32], Dict[uint32, SubEpochSummary]]:
        """
        Returns a dictionary with all blocks, as well as the header hash of the peak,
        if present.
        """

        res = await self.db.execute(
            "SELECT header_hash from block_records WHERE is_peak = 1")
        row = await res.fetchone()
        await res.close()
        if row is None:
            return {}, {}

        peak: bytes32 = bytes.fromhex(row[0])
        cursor = await self.db.execute(
            "SELECT header_hash,prev_hash,height,sub_epoch_summary from block_records"
        )
        rows = await cursor.fetchall()
        await cursor.close()
        hash_to_prev_hash: Dict[bytes32, bytes32] = {}
        hash_to_height: Dict[bytes32, uint32] = {}
        hash_to_summary: Dict[bytes32, SubEpochSummary] = {}

        for row in rows:
            hash_to_prev_hash[bytes.fromhex(row[0])] = bytes.fromhex(row[1])
            hash_to_height[bytes.fromhex(row[0])] = row[2]
            if row[3] is not None:
                hash_to_summary[bytes.fromhex(
                    row[0])] = SubEpochSummary.from_bytes(row[3])

        height_to_hash: Dict[uint32, bytes32] = {}
        sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}

        curr_header_hash = peak
        curr_height = hash_to_height[curr_header_hash]
        while True:
            height_to_hash[curr_height] = curr_header_hash
            if curr_header_hash in hash_to_summary:
                sub_epoch_summaries[curr_height] = hash_to_summary[
                    curr_header_hash]
            if curr_height == 0:
                break
            curr_header_hash = hash_to_prev_hash[curr_header_hash]
            curr_height = hash_to_height[curr_header_hash]
        return height_to_hash, sub_epoch_summaries
Example no. 24
class WalletNode:
    key_config: Dict
    config: Dict
    constants: ConsensusConstants
    server: Optional[ChiaServer]
    log: logging.Logger
    wallet_peers: WalletPeers
    # Maintains the state of the wallet (blockchain and transactions), handles DB connections
    wallet_state_manager: Optional[WalletStateManager]

    # How far away from the LCA we must be to perform a full sync. Any closer, and we do a
    # short sync instead, which consists of consecutive requests for the previous block
    short_sync_threshold: int
    _shut_down: bool
    root_path: Path
    state_changed_callback: Optional[Callable]
    syncing: bool
    full_node_peer: Optional[PeerInfo]
    peer_task: Optional[asyncio.Task]
    logged_in: bool
    wallet_peers_initialized: bool

    def __init__(
        self,
        config: Dict,
        keychain: Keychain,
        root_path: Path,
        consensus_constants: ConsensusConstants,
        name: Optional[str] = None,
    ):
        self.config = config
        self.constants = consensus_constants
        self.root_path = root_path
        self.log = logging.getLogger(name if name else __name__)
        # Normal operation data
        self.cached_blocks: Dict = {}
        self.future_block_hashes: Dict = {}
        self.keychain = keychain

        # Sync data
        self._shut_down = False
        self.proof_hashes: List = []
        self.header_hashes: List = []
        self.header_hashes_error = False
        self.short_sync_threshold = 15  # Change the test when changing this
        self.potential_blocks_received: Dict = {}
        self.potential_header_hashes: Dict = {}
        self.state_changed_callback = None
        self.wallet_state_manager = None
        self.backup_initialized = False  # Delay first launch sync after user imports backup info or decides to skip
        self.server = None
        self.wsm_close_task = None
        self.sync_task: Optional[asyncio.Task] = None
        self.new_peak_lock: Optional[asyncio.Lock] = None
        self.logged_in_fingerprint: Optional[int] = None
        self.peer_task = None
        self.logged_in = False
        self.wallet_peers_initialized = False
        self.last_new_peak_messages = LRUCache(5)

    def get_key_for_fingerprint(self, fingerprint: Optional[int]) -> Optional[PrivateKey]:
        private_keys = self.keychain.get_all_private_keys()
        if len(private_keys) == 0:
            self.log.warning("No keys present. Create keys with the UI, or with the 'chia keys' program.")
            return None

        private_key: Optional[PrivateKey] = None
        if fingerprint is not None:
            for sk, _ in private_keys:
                if sk.get_g1().get_fingerprint() == fingerprint:
                    private_key = sk
                    break
        else:
            private_key = private_keys[0][0]  # If no fingerprint, take the first private key
        return private_key

    async def _start(
        self,
        fingerprint: Optional[int] = None,
        new_wallet: bool = False,
        backup_file: Optional[Path] = None,
        skip_backup_import: bool = False,
    ) -> bool:
        private_key = self.get_key_for_fingerprint(fingerprint)
        if private_key is None:
            self.logged_in = False
            return False

        if self.config.get("enable_profiler", False):
            asyncio.create_task(profile_task(self.root_path, "wallet", self.log))

        db_path_key_suffix = str(private_key.get_g1().get_fingerprint())
        db_path_replaced: str = (
            self.config["database_path"]
            .replace("CHALLENGE", self.config["selected_network"])
            .replace("KEY", db_path_key_suffix)
        )
        path = path_from_root(self.root_path, db_path_replaced)
        mkdir(path.parent)

        assert self.server is not None
        self.wallet_state_manager = await WalletStateManager.create(
            private_key, self.config, path, self.constants, self.server
        )

        self.wsm_close_task = None

        assert self.wallet_state_manager is not None

        backup_settings: BackupInitialized = self.wallet_state_manager.user_settings.get_backup_settings()
        if backup_settings.user_initialized is False:
            if new_wallet is True:
                await self.wallet_state_manager.user_settings.user_created_new_wallet()
                self.wallet_state_manager.new_wallet = True
            elif skip_backup_import is True:
                await self.wallet_state_manager.user_settings.user_skipped_backup_import()
            elif backup_file is not None:
                await self.wallet_state_manager.import_backup_info(backup_file)
            else:
                self.backup_initialized = False
                await self.wallet_state_manager.close_all_stores()
                self.wallet_state_manager = None
                self.logged_in = False
                return False

        self.backup_initialized = True

        # Start peers here after the backup initialization has finished
        # We only want to do this once per instantiation
        # However, doing it earlier before backup initialization causes
        # the wallet to spam the introducer
        if self.wallet_peers_initialized is False:
            asyncio.create_task(self.wallet_peers.start())
            self.wallet_peers_initialized = True

        if backup_file is not None:
            json_dict = open_backup_file(backup_file, self.wallet_state_manager.private_key)
            if "start_height" in json_dict["data"]:
                start_height = json_dict["data"]["start_height"]
                self.config["starting_height"] = max(0, start_height - self.config["start_height_buffer"])
            else:
                self.config["starting_height"] = 0
        else:
            self.config["starting_height"] = 0

        if self.state_changed_callback is not None:
            self.wallet_state_manager.set_callback(self.state_changed_callback)

        self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
        self._shut_down = False

        self.peer_task = asyncio.create_task(self._periodically_check_full_node())
        self.sync_event = asyncio.Event()
        self.sync_task = asyncio.create_task(self.sync_job())
        self.logged_in_fingerprint = fingerprint
        self.logged_in = True
        return True

    def _close(self):
        self.log.info("self._close")
        self.logged_in_fingerprint = None
        self._shut_down = True

    async def _await_closed(self):
        self.log.info("self._await_closed")
        await self.server.close_all_connections()
        asyncio.create_task(self.wallet_peers.ensure_is_closed())
        if self.wallet_state_manager is not None:
            await self.wallet_state_manager.close_all_stores()
            self.wallet_state_manager = None
        if self.sync_task is not None:
            self.sync_task.cancel()
            self.sync_task = None
        if self.peer_task is not None:
            self.peer_task.cancel()
            self.peer_task = None
        self.logged_in = False

    def _set_state_changed_callback(self, callback: Callable):
        self.state_changed_callback = callback

        if self.wallet_state_manager is not None:
            self.wallet_state_manager.set_callback(self.state_changed_callback)
            self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)

    def _pending_tx_handler(self):
        if self.wallet_state_manager is None or self.backup_initialized is False:
            return None
        asyncio.create_task(self._resend_queue())

    async def _action_messages(self) -> List[Message]:
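        """
        Converts pending wallet actions into protocol messages; currently only
        request_puzzle_solution actions are handled.
        """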
        if self.wallet_state_manager is None or self.backup_initialized is False:
            return []
        actions: List[WalletAction] = await self.wallet_state_manager.action_store.get_all_pending_actions()
        result: List[Message] = []
        for action in actions:
            data = json.loads(action.data)
            action_data = data["data"]["action_data"]
            if action.name == "request_puzzle_solution":
                coin_name = bytes32(hexstr_to_bytes(action_data["coin_name"]))
                height = uint32(action_data["height"])
                msg = make_msg(
                    ProtocolMessageTypes.request_puzzle_solution,
                    wallet_protocol.RequestPuzzleSolution(coin_name, height),
                )
                result.append(msg)

        return result

    async def _resend_queue(self):
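        """
        Resends unconfirmed transactions and pending action messages to connected full nodes,
        skipping peers that have already received a given transaction.
        """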
        if (
            self._shut_down
            or self.server is None
            or self.wallet_state_manager is None
            or self.backup_initialized is False
        ):
            return None

        for msg, sent_peers in await self._messages_to_resend():
            if (
                self._shut_down
                or self.server is None
                or self.wallet_state_manager is None
                or self.backup_initialized is False
            ):
                return None
            full_nodes = self.server.get_full_node_connections()
            for peer in full_nodes:
                if peer.peer_node_id in sent_peers:
                    continue
                await peer.send_message(msg)

        for msg in await self._action_messages():
            if (
                self._shut_down
                or self.server is None
                or self.wallet_state_manager is None
                or self.backup_initialized is False
            ):
                return None
            await self.server.send_to_all([msg], NodeType.FULL_NODE)

    async def _messages_to_resend(self) -> List[Tuple[Message, Set[bytes32]]]:
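        """
        Returns one send_transaction message per unconfirmed transaction, each paired with
        the set of peer ids that have already been sent that transaction.
        """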
        if self.wallet_state_manager is None or self.backup_initialized is False or self._shut_down:
            return []
        messages: List[Tuple[Message, Set[bytes32]]] = []

        records: List[TransactionRecord] = await self.wallet_state_manager.tx_store.get_not_sent()

        for record in records:
            if record.spend_bundle is None:
                continue
            msg = make_msg(
                ProtocolMessageTypes.send_transaction,
                wallet_protocol.SendTransaction(record.spend_bundle),
            )
            already_sent = set()
            for peer, status, _ in record.sent_to:
                already_sent.add(hexstr_to_bytes(peer))
            messages.append((msg, already_sent))

        return messages

    def set_server(self, server: ChiaServer):
        self.server = server
        DNS_SERVERS_EMPTY: list = []
        # TODO: Perhaps use a different set of DNS seeders for wallets, to split the traffic.
        self.wallet_peers = WalletPeers(
            self.server,
            self.root_path,
            self.config["target_peer_count"],
            self.config["wallet_peers_path"],
            self.config["introducer_peer"],
            DNS_SERVERS_EMPTY,
            self.config["peer_connect_interval"],
            self.config["selected_network"],
            None,
            self.log,
        )

    async def on_connect(self, peer: WSChiaConnection):
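        """
        Called when a new peer connects: resends any transactions the peer has not yet seen
        and, when we are not connected to our configured full node, hands the peer to the
        wallet peer manager.
        """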
        if self.wallet_state_manager is None or self.backup_initialized is False:
            return None
        messages_peer_ids = await self._messages_to_resend()
        self.wallet_state_manager.state_changed("add_connection")
        for msg, peer_ids in messages_peer_ids:
            if peer.peer_node_id in peer_ids:
                continue
            await peer.send_message(msg)
        if not self.has_full_node() and self.wallet_peers is not None:
            asyncio.create_task(self.wallet_peers.on_connect(peer))

    async def _periodically_check_full_node(self) -> None:
        tries = 0
        while not self._shut_down and tries < 5:
            if self.has_full_node():
                await self.wallet_peers.ensure_is_closed()
                if self.wallet_state_manager is not None:
                    self.wallet_state_manager.state_changed("add_connection")
                break
            tries += 1
            await asyncio.sleep(self.config["peer_connect_interval"])

    def has_full_node(self) -> bool:
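        """
        Returns True if we are connected to the full node configured under "full_node_peer",
        closing any other full-node connections as a side effect; False otherwise.
        """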
        if self.server is None:
            return False
        if "full_node_peer" in self.config:
            full_node_peer = PeerInfo(
                self.config["full_node_peer"]["host"],
                self.config["full_node_peer"]["port"],
            )
            peers = [c.get_peer_info() for c in self.server.get_full_node_connections()]
            full_node_resolved = PeerInfo(socket.gethostbyname(full_node_peer.host), full_node_peer.port)
            if full_node_peer in peers or full_node_resolved in peers:
                self.log.info(f"Will not attempt to connect to other nodes, already connected to {full_node_peer}")
                for connection in self.server.get_full_node_connections():
                    if (
                        connection.get_peer_info() != full_node_peer
                        and connection.get_peer_info() != full_node_resolved
                    ):
                        self.log.info(f"Closing unnecessary connection to {connection.get_peer_info()}.")
                        asyncio.create_task(connection.close())
                return True
        return False

    async def complete_blocks(self, header_blocks: List[HeaderBlock], peer: WSChiaConnection):
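        """
        For each header block, fetches the additions and removals of transaction blocks from
        the peer, wraps them in HeaderBlockRecords, and feeds them to the wallet blockchain.
        """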
        if self.wallet_state_manager is None:
            return None
        header_block_records: List[HeaderBlockRecord] = []
        assert self.server
        trusted = self.server.is_trusted_peer(peer, self.config["trusted_peers"])
        async with self.wallet_state_manager.blockchain.lock:
            for block in header_blocks:
                if block.is_transaction_block:
                    # Find additions and removals
                    (additions, removals,) = await self.wallet_state_manager.get_filter_additions_removals(
                        block, block.transactions_filter, None
                    )

                    # Get Additions
                    added_coins = await self.get_additions(peer, block, additions)
                    if added_coins is None:
                        raise ValueError("Failed to fetch additions")

                    # Get removals
                    removed_coins = await self.get_removals(peer, block, added_coins, removals)
                    if removed_coins is None:
                        raise ValueError("Failed to fetch removals")
                    hbr = HeaderBlockRecord(block, added_coins, removed_coins)
                else:
                    hbr = HeaderBlockRecord(block, [], [])
                header_block_records.append(hbr)
                (
                    result,
                    error,
                    fork_h,
                ) = await self.wallet_state_manager.blockchain.receive_block(hbr, trusted=trusted)
                if result == ReceiveBlockResult.NEW_PEAK:
                    if not self.wallet_state_manager.sync_mode:
                        self.wallet_state_manager.blockchain.clean_block_records()
                    self.wallet_state_manager.state_changed("new_block")
                    self.wallet_state_manager.state_changed("sync_changed")
                elif result == ReceiveBlockResult.INVALID_BLOCK:
                    self.log.info(f"Invalid block from peer: {peer.get_peer_info()} {error}")
                    await peer.close()
                    return None
                else:
                    self.log.debug(f"Result: {result}")

    async def new_peak_wallet(self, peak: wallet_protocol.NewPeakWallet, peer: WSChiaConnection):
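        """
        Handles a new peak announced by a peer. If the peak is close to our current height,
        fetches the missing headers backwards and completes them directly; otherwise requests
        and validates a weight proof, records the fork point, and triggers a full sync. While
        already syncing, the peak message is buffered and replayed after the sync finishes.
        """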
        if self.wallet_state_manager is None:
            return None

        curr_peak = self.wallet_state_manager.blockchain.get_peak()
        if curr_peak is not None and curr_peak.weight >= peak.weight:
            return None
        if self.new_peak_lock is None:
            self.new_peak_lock = asyncio.Lock()
        async with self.new_peak_lock:
            request = wallet_protocol.RequestBlockHeader(peak.height)
            response: Optional[RespondBlockHeader] = await peer.request_block_header(request)

            if response is None or not isinstance(response, RespondBlockHeader) or response.header_block is None:
                return None

            header_block = response.header_block

            if (curr_peak is None and header_block.height < self.constants.WEIGHT_PROOF_RECENT_BLOCKS) or (
                curr_peak is not None and curr_peak.height > header_block.height - 200
            ):
                top = header_block
                blocks = [top]
                # Fetch blocks backwards until we hit the one that we have,
                # then complete them with additions / removals going forward
                while not self.wallet_state_manager.blockchain.contains_block(top.prev_header_hash) and top.height > 0:
                    request_prev = wallet_protocol.RequestBlockHeader(top.height - 1)
                    response_prev: Optional[RespondBlockHeader] = await peer.request_block_header(request_prev)
                    if response_prev is None:
                        return None
                    if not isinstance(response_prev, RespondBlockHeader):
                        return None
                    prev_head = response_prev.header_block
                    blocks.append(prev_head)
                    top = prev_head
                blocks.reverse()
                await self.complete_blocks(blocks, peer)
                await self.wallet_state_manager.create_more_puzzle_hashes()
            elif header_block.height >= self.constants.WEIGHT_PROOF_RECENT_BLOCKS:
                # Request weight proof
                # Sync if PoW validates
                if self.wallet_state_manager.sync_mode:
                    self.last_new_peak_messages.put(peer, peak)
                    return None
                weight_request = RequestProofOfWeight(header_block.height, header_block.header_hash)
                weight_proof_response: Optional[RespondProofOfWeight] = await peer.request_proof_of_weight(
                    weight_request, timeout=360
                )
                if weight_proof_response is None:
                    return None
                weight_proof = weight_proof_response.wp
                if self.wallet_state_manager is None:
                    return None
                if self.server is not None and self.server.is_trusted_peer(peer, self.config["trusted_peers"]):
                    valid, fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point_no_validations(
                        weight_proof
                    )
                else:
                    valid, fork_point, _ = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof(
                        weight_proof
                    )
                    if not valid:
                        self.log.error(
                            f"Invalid weight proof, num of epochs: {len(weight_proof.sub_epochs)},"
                            f" num of recent blocks: {len(weight_proof.recent_chain_data)}"
                        )
                        self.log.debug(f"{weight_proof}")
                        return None
                self.log.info(f"Validated, fork point is {fork_point}")
                self.wallet_state_manager.sync_store.add_potential_fork_point(
                    header_block.header_hash, uint32(fork_point)
                )
                self.wallet_state_manager.sync_store.add_potential_peak(header_block)
                self.start_sync()

    def start_sync(self) -> None:
        self.log.info("self.sync_event.set()")
        self.sync_event.set()

    async def check_new_peak(self) -> None:
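        """
        Triggers a sync (after a short delay) if any stored potential peak is heavier than
        our current peak.
        """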
        if self.wallet_state_manager is None:
            return None

        current_peak: Optional[BlockRecord] = self.wallet_state_manager.blockchain.get_peak()
        if current_peak is None:
            return None
        potential_peaks: List[
            Tuple[bytes32, HeaderBlock]
        ] = self.wallet_state_manager.sync_store.get_potential_peaks_tuples()
        for _, block in potential_peaks:
            if current_peak.weight < block.weight:
                await asyncio.sleep(5)
                self.start_sync()
                return None

    async def sync_job(self) -> None:
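        """
        Background loop: waits for sync_event, runs _sync with sync mode enabled, and then
        replays the new-peak messages buffered while syncing.
        """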
        while True:
            self.log.info("Loop start in sync job")
            if self._shut_down is True:
                break
            asyncio.create_task(self.check_new_peak())
            await self.sync_event.wait()
            self.last_new_peak_messages = LRUCache(5)
            self.sync_event.clear()

            if self._shut_down is True:
                break
            try:
                assert self.wallet_state_manager is not None
                self.wallet_state_manager.set_sync_mode(True)
                await self._sync()
            except Exception as e:
                tb = traceback.format_exc()
                self.log.error(f"Loop exception in sync {e}. {tb}")
            finally:
                if self.wallet_state_manager is not None:
                    self.wallet_state_manager.set_sync_mode(False)
                for peer, peak in self.last_new_peak_messages.cache.items():
                    asyncio.create_task(self.new_peak_wallet(peak, peer))
            self.log.info("Loop end in sync job")

    async def _sync(self) -> None:
        """
        Wallet has fallen far behind (or is starting up for the first time), and must be synced
        up to the LCA of the blockchain.
        """
        if self.wallet_state_manager is None or self.backup_initialized is False or self.server is None:
            return None

        highest_weight: uint128 = uint128(0)
        peak_height: uint32 = uint32(0)
        peak: Optional[HeaderBlock] = None
        potential_peaks: List[
            Tuple[bytes32, HeaderBlock]
        ] = self.wallet_state_manager.sync_store.get_potential_peaks_tuples()

        self.log.info(f"Have collected {len(potential_peaks)} potential peaks")

        for header_hash, potential_peak_block in potential_peaks:
            if potential_peak_block.weight > highest_weight:
                highest_weight = potential_peak_block.weight
                peak_height = potential_peak_block.height
                peak = potential_peak_block

        if peak is None or peak_height == 0:
            return None

        if self.wallet_state_manager.peak is not None and highest_weight <= self.wallet_state_manager.peak.weight:
            self.log.info("Not performing sync, already caught up.")
            return None

        peers: List[WSChiaConnection] = self.server.get_full_node_connections()
        if len(peers) == 0:
            self.log.info("No peers to sync to")
            return None

        async with self.wallet_state_manager.blockchain.lock:
            fork_height = None
            if peak is not None:
                fork_height = self.wallet_state_manager.sync_store.get_potential_fork_point(peak.header_hash)
                our_peak_height = self.wallet_state_manager.blockchain.get_peak_height()
                ses_heights = self.wallet_state_manager.blockchain.get_ses_heights()
                if len(ses_heights) > 2 and our_peak_height is not None:
                    ses_heights.sort()
                    max_fork_ses_height = ses_heights[-3]
                    # This is the fork point in SES in the case where no fork was detected
                    if (
                        self.wallet_state_manager.blockchain.get_peak_height() is not None
                        and fork_height == max_fork_ses_height
                    ):
                        peers = self.server.get_full_node_connections()
                        for peer in peers:
                            # Grab a block at peak + 1 and check if fork point is actually our current height
                            potential_height = uint32(our_peak_height + 1)
                            block_response: Optional[Any] = await peer.request_header_blocks(
                                wallet_protocol.RequestHeaderBlocks(potential_height, potential_height)
                            )
                            if block_response is not None and isinstance(
                                block_response, wallet_protocol.RespondHeaderBlocks
                            ):
                                our_peak = self.wallet_state_manager.blockchain.get_peak()
                                if (
                                    our_peak is not None
                                    and block_response.header_blocks[0].prev_header_hash == our_peak.header_hash
                                ):
                                    fork_height = our_peak_height
                                break
            if fork_height is None:
                fork_height = uint32(0)
            await self.wallet_state_manager.blockchain.warmup(fork_height)
            batch_size = self.constants.MAX_BLOCK_COUNT_PER_REQUESTS
            advanced_peak = False
            for i in range(max(0, fork_height - 1), peak_height, batch_size):
                start_height = i
                end_height = min(peak_height, start_height + batch_size)
                peers = self.server.get_full_node_connections()
                added = False
                for peer in peers:
                    try:
                        added, advanced_peak = await self.fetch_blocks_and_validate(
                            peer, uint32(start_height), uint32(end_height), None if advanced_peak else fork_height
                        )
                        if added:
                            break
                    except Exception as e:
                        await peer.close()
                        exc = traceback.format_exc()
                        self.log.error(f"Error while trying to fetch from peer:{e} {exc}")
                if not added:
                    raise RuntimeError(f"Was not able to add blocks {start_height}-{end_height}")

                peak = self.wallet_state_manager.blockchain.get_peak()
                assert peak is not None
                self.wallet_state_manager.blockchain.clean_block_record(
                    min(
                        end_height - self.constants.BLOCKS_CACHE_SIZE,
                        peak.height - self.constants.BLOCKS_CACHE_SIZE,
                    )
                )

    async def fetch_blocks_and_validate(
        self,
        peer: WSChiaConnection,
        height_start: uint32,
        height_end: uint32,
        fork_point_with_peak: Optional[uint32],
    ) -> Tuple[bool, bool]:
        """
        Returns whether the blocks validated, and whether the peak was advanced
        """
        if self.wallet_state_manager is None:
            return False, False

        self.log.info(f"Requesting blocks {height_start}-{height_end}")
        request = RequestHeaderBlocks(uint32(height_start), uint32(height_end))
        res: Optional[RespondHeaderBlocks] = await peer.request_header_blocks(request)
        if res is None or not isinstance(res, RespondHeaderBlocks):
            raise ValueError("Peer returned no response")
        header_blocks: List[HeaderBlock] = res.header_blocks
        advanced_peak = False
        if header_blocks is None:
            raise ValueError(f"No response from peer {peer}")
        if (
            self.full_node_peer is not None and peer.peer_host == self.full_node_peer.host
        ) or peer.peer_host == "127.0.0.1":
            trusted = True
            pre_validation_results: Optional[List[PreValidationResult]] = None
        else:
            trusted = False
            pre_validation_results = await self.wallet_state_manager.blockchain.pre_validate_blocks_multiprocessing(
                header_blocks
            )
            if pre_validation_results is None:
                return False, advanced_peak
            assert len(header_blocks) == len(pre_validation_results)

        for i in range(len(header_blocks)):
            header_block = header_blocks[i]
            if not trusted and pre_validation_results is not None and pre_validation_results[i].error is not None:
                raise ValidationError(Err(pre_validation_results[i].error))

            fork_point_with_old_peak = None if advanced_peak else fork_point_with_peak
            if header_block.is_transaction_block:
                # Find additions and removals
                (additions, removals,) = await self.wallet_state_manager.get_filter_additions_removals(
                    header_block, header_block.transactions_filter, fork_point_with_old_peak
                )

                # Get Additions
                added_coins = await self.get_additions(peer, header_block, additions)
                if added_coins is None:
                    raise ValueError("Failed to fetch additions")

                # Get removals
                removed_coins = await self.get_removals(peer, header_block, added_coins, removals)
                if removed_coins is None:
                    raise ValueError("Failed to fetch removals")

                header_block_record = HeaderBlockRecord(header_block, added_coins, removed_coins)
            else:
                header_block_record = HeaderBlockRecord(header_block, [], [])
            start_t = time.time()
            if trusted:
                (result, error, fork_h,) = await self.wallet_state_manager.blockchain.receive_block(
                    header_block_record, None, trusted, fork_point_with_old_peak
                )
            else:
                assert pre_validation_results is not None
                (result, error, fork_h,) = await self.wallet_state_manager.blockchain.receive_block(
                    header_block_record, pre_validation_results[i], trusted, fork_point_with_old_peak
                )
            self.log.debug(
                f"Time taken to validate {header_block.height} with fork "
                f"{fork_point_with_old_peak}: {time.time() - start_t}"
            )
            if result == ReceiveBlockResult.NEW_PEAK:
                advanced_peak = True
                self.wallet_state_manager.state_changed("new_block")
            elif result == ReceiveBlockResult.INVALID_BLOCK:
                raise ValueError("Value error peer sent us invalid block")
        if advanced_peak:
            await self.wallet_state_manager.create_more_puzzle_hashes()
        return True, advanced_peak

    def validate_additions(
        self,
        coins: List[Tuple[bytes32, List[Coin]]],
        proofs: Optional[List[Tuple[bytes32, bytes, Optional[bytes]]]],
        root,
    ):
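        """
        Validates additions against the additions Merkle root: with no proofs, recomputes the
        root over all (puzzle_hash, coins) pairs; with proofs, checks an exclusion proof for
        empty coin lists and inclusion proofs for both the coin list hash and the puzzle hash.
        """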
        if proofs is None:
            # Verify root
            additions_merkle_set = MerkleSet()

            # Addition Merkle set contains puzzlehash and hash of all coins with that puzzlehash
            for puzzle_hash, coins_l in coins:
                additions_merkle_set.add_already_hashed(puzzle_hash)
                additions_merkle_set.add_already_hashed(hash_coin_list(coins_l))

            additions_root = additions_merkle_set.get_root()
            if root != additions_root:
                return False
        else:
            for i in range(len(coins)):
                assert coins[i][0] == proofs[i][0]
                coin_list_1: List[Coin] = coins[i][1]
                puzzle_hash_proof: bytes32 = proofs[i][1]
                coin_list_proof: Optional[bytes32] = proofs[i][2]
                if len(coin_list_1) == 0:
                    # Verify exclusion proof for puzzle hash
                    not_included = confirm_not_included_already_hashed(
                        root,
                        coins[i][0],
                        puzzle_hash_proof,
                    )
                    if not_included is False:
                        return False
                else:
                    try:
                        # Verify inclusion proof for coin list
                        included = confirm_included_already_hashed(
                            root,
                            hash_coin_list(coin_list_1),
                            coin_list_proof,
                        )
                        if included is False:
                            return False
                    except AssertionError:
                        return False
                    try:
                        # Verify inclusion proof for puzzle hash
                        included = confirm_included_already_hashed(
                            root,
                            coins[i][0],
                            puzzle_hash_proof,
                        )
                        if included is False:
                            return False
                    except AssertionError:
                        return False

        return True

    def validate_removals(self, coins, proofs, root):
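        """
        Validates removals against the removals Merkle root: with no proofs, recomputes the
        root over all returned coins; with proofs, verifies per-coin inclusion or exclusion
        proofs in the same order as the coins.
        """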
        if proofs is None:
            # If there are no proofs, it means all removals were returned in the response.
            # we must find the ones relevant to our wallets.

            # Verify removals root
            removals_merkle_set = MerkleSet()
            for name_coin in coins:
                # TODO review all verification
                name, coin = name_coin
                if coin is not None:
                    removals_merkle_set.add_already_hashed(coin.name())
            removals_root = removals_merkle_set.get_root()
            if root != removals_root:
                return False
        else:
            # This means the full node has responded only with the relevant removals
            # for our wallet. Each merkle proof must be verified.
            if len(coins) != len(proofs):
                return False
            for i in range(len(coins)):
                # Coins are in the same order as proofs
                if coins[i][0] != proofs[i][0]:
                    return False
                coin = coins[i][1]
                if coin is None:
                    # Verifies merkle proof of exclusion
                    not_included = confirm_not_included_already_hashed(
                        root,
                        coins[i][0],
                        proofs[i][1],
                    )
                    if not_included is False:
                        return False
                else:
                    # Verifies merkle proof of inclusion of coin name
                    if coins[i][0] != coin.name():
                        return False
                    included = confirm_included_already_hashed(
                        root,
                        coin.name(),
                        proofs[i][1],
                    )
                    if included is False:
                        return False
        return True

    async def get_additions(self, peer: WSChiaConnection, block_i, additions) -> Optional[List[Coin]]:
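        """
        Requests the additions for a block from the peer and validates them against the
        block's additions root. Returns the added coins, an empty list when nothing was
        requested, or None (closing the peer) on failure.
        """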
        if len(additions) > 0:
            additions_request = RequestAdditions(block_i.height, block_i.header_hash, additions)
            additions_res: Optional[Union[RespondAdditions, RejectAdditionsRequest]] = await peer.request_additions(
                additions_request
            )
            if additions_res is None:
                await peer.close()
                return None
            elif isinstance(additions_res, RespondAdditions):
                validated = self.validate_additions(
                    additions_res.coins,
                    additions_res.proofs,
                    block_i.foliage_transaction_block.additions_root,
                )
                if not validated:
                    await peer.close()
                    return None
                added_coins = []
                for ph_coins in additions_res.coins:
                    ph, coins = ph_coins
                    added_coins.extend(coins)
                return added_coins
            elif isinstance(additions_res, RejectAdditionsRequest):
                await peer.close()
                return None
            return None
        else:
            return []  # No added coins

    async def get_removals(self, peer: WSChiaConnection, block_i, additions, removals) -> Optional[List[Coin]]:
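        """
        Requests removals for a block, asking for all removals when an added coin belongs to
        a coloured-coin or DID wallet, and validates them against the block's removals root.
        Returns the removed coins, an empty list when nothing was requested, or None on failure.
        """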
        assert self.wallet_state_manager is not None
        request_all_removals = False
        # Check if we need all removals
        for coin in additions:
            puzzle_store = self.wallet_state_manager.puzzle_store
            record_info: Optional[DerivationRecord] = await puzzle_store.get_derivation_record_for_puzzle_hash(
                coin.puzzle_hash.hex()
            )
            if record_info is not None and record_info.wallet_type == WalletType.COLOURED_COIN:
                # TODO why ?
                request_all_removals = True
                break
            if record_info is not None and record_info.wallet_type == WalletType.DISTRIBUTED_ID:
                request_all_removals = True
                break

        if len(removals) > 0 or request_all_removals:
            if request_all_removals:
                removals_request = wallet_protocol.RequestRemovals(block_i.height, block_i.header_hash, None)
            else:
                removals_request = wallet_protocol.RequestRemovals(block_i.height, block_i.header_hash, removals)
            removals_res: Optional[Union[RespondRemovals, RejectRemovalsRequest]] = await peer.request_removals(
                removals_request
            )
            if removals_res is None:
                return None
            elif isinstance(removals_res, RespondRemovals):
                validated = self.validate_removals(
                    removals_res.coins,
                    removals_res.proofs,
                    block_i.foliage_transaction_block.removals_root,
                )
                if validated is False:
                    await peer.close()
                    return None
                removed_coins = []
                for _, removed_coin in removals_res.coins:
                    if removed_coin is not None:
                        removed_coins.append(removed_coin)

                return removed_coins
            elif isinstance(removals_res, RejectRemovalsRequest):
                return None
            else:
                return None

        else:
            return []
Example no. 25
                return []
        pairings.append(pairing)

    for i, pairing in enumerate(pairings):
        if pairing is None:
            aug_msg = bytes(pks[i]) + msgs[i]
            aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
            pairing = pks[i].pair(aug_hash)

            h = bytes(std_hash(aug_msg))
            cache.put(h, pairing)
            pairings[i] = pairing

    return pairings


LOCAL_CACHE: LRUCache = LRUCache(10000)


def aggregate_verify(pks: List[G1Element],
                     msgs: List[bytes],
                     sig: G2Element,
                     force_cache: bool = False,
                     cache: LRUCache = LOCAL_CACHE) -> bool:
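    """
    Verifies an aggregate signature using cached pairings where available, falling back to
    AugSchemeMPL.aggregate_verify when get_pairings declines to use the cache.
    """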
    pairings: List[GTElement] = get_pairings(cache, pks, msgs, force_cache)
    if len(pairings) == 0:
        return AugSchemeMPL.aggregate_verify(pks, msgs, sig)

    pairings_prod: GTElement = functools.reduce(GTElement.__mul__, pairings)
    return pairings_prod == sig.pair(G1Element.generator())
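

# Usage sketch (illustrative addition, not part of the original snippet): exercising
# aggregate_verify with two freshly generated keys. Assumes the blspy-style
# AugSchemeMPL.key_gen / sign / aggregate API used by the code above.
if __name__ == "__main__":
    sk1 = AugSchemeMPL.key_gen(bytes([1] * 32))
    sk2 = AugSchemeMPL.key_gen(bytes([2] * 32))
    pks = [sk1.get_g1(), sk2.get_g1()]
    msgs = [b"hello", b"chia"]
    agg_sig = AugSchemeMPL.aggregate(
        [AugSchemeMPL.sign(sk1, msgs[0]), AugSchemeMPL.sign(sk2, msgs[1])]
    )
    # force_cache=True computes and caches the missing pairings up front;
    # the second call is then answered entirely from LOCAL_CACHE.
    assert aggregate_verify(pks, msgs, agg_sig, force_cache=True)
    assert aggregate_verify(pks, msgs, agg_sig)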