Example #1
class ChainManager:
    def __init__(self, state):
        self.state = state
        self.tx_pool = TransactionPool(None)
        self.last_block = Block.from_json(GenesisBlock().to_json())
        self.current_difficulty = StringToUInt256(
            str(config.dev.genesis_difficulty))

        self.trigger_miner = False

    @property
    def height(self):
        return self.last_block.block_number

    def set_broadcast_tx(self, broadcast_tx):
        self.tx_pool.set_broadcast_tx(broadcast_tx)

    def get_last_block(self) -> Block:
        return self.last_block

    def get_cumulative_difficulty(self):
        last_block_metadata = self.state.get_block_metadata(
            self.last_block.headerhash)
        return last_block_metadata.cumulative_difficulty

    def load(self, genesis_block):
        height = self.state.get_mainchain_height()

        if height == -1:
            self.state.put_block(genesis_block, None)
            block_number_mapping = qrl_pb2.BlockNumberMapping(
                headerhash=genesis_block.headerhash,
                prev_headerhash=genesis_block.prev_headerhash)

            self.state.put_block_number_mapping(genesis_block.block_number,
                                                block_number_mapping, None)
            parent_difficulty = StringToUInt256(
                str(config.dev.genesis_difficulty))

            self.current_difficulty, _ = DifficultyTracker.get(
                measurement=config.dev.mining_setpoint_blocktime,
                parent_difficulty=parent_difficulty)

            block_metadata = BlockMetadata.create()

            block_metadata.set_orphan(False)
            block_metadata.set_block_difficulty(self.current_difficulty)
            block_metadata.set_cumulative_difficulty(self.current_difficulty)

            self.state.put_block_metadata(genesis_block.headerhash,
                                          block_metadata, None)
            addresses_state = dict()
            for genesis_balance in GenesisBlock().genesis_balance:
                bytes_addr = genesis_balance.address
                addresses_state[bytes_addr] = AddressState.get_default(
                    bytes_addr)
                addresses_state[
                    bytes_addr]._data.balance = genesis_balance.balance

            for tx_idx in range(1, len(genesis_block.transactions)):
                tx = Transaction.from_pbdata(
                    genesis_block.transactions[tx_idx])
                for addr in tx.addrs_to:
                    addresses_state[addr] = AddressState.get_default(addr)

            coinbase_tx = Transaction.from_pbdata(
                genesis_block.transactions[0])

            if not isinstance(coinbase_tx, CoinBase):
                return False

            addresses_state[coinbase_tx.addr_to] = AddressState.get_default(
                coinbase_tx.addr_to)

            if not coinbase_tx.validate_extended():
                return False

            coinbase_tx.apply_state_changes(addresses_state)

            for tx_idx in range(1, len(genesis_block.transactions)):
                tx = Transaction.from_pbdata(
                    genesis_block.transactions[tx_idx])
                tx.apply_state_changes(addresses_state)

            self.state.put_addresses_state(addresses_state)
            self.state.update_tx_metadata(genesis_block, None)
            self.state.update_mainchain_height(0, None)
        else:
            self.last_block = self.get_block_by_number(height)
            self.current_difficulty = self.state.get_block_metadata(
                self.last_block.headerhash).block_difficulty

    def validate_mining_nonce(self, block, enable_logging=False):
        parent_metadata = self.state.get_block_metadata(block.prev_headerhash)
        parent_block = self.state.get_block(block.prev_headerhash)

        measurement = self.state.get_measurement(block.timestamp,
                                                 block.prev_headerhash,
                                                 parent_metadata)
        diff, target = DifficultyTracker.get(
            measurement=measurement,
            parent_difficulty=parent_metadata.block_difficulty)

        if enable_logging:
            logger.debug('-----------------START--------------------')
            logger.debug('Validate #%s', block.block_number)
            logger.debug('block.timestamp %s', block.timestamp)
            logger.debug('parent_block.timestamp %s', parent_block.timestamp)
            logger.debug('parent_block.difficulty %s',
                         UInt256ToString(parent_metadata.block_difficulty))
            logger.debug('diff : %s | target : %s', UInt256ToString(diff),
                         target)
            logger.debug('-------------------END--------------------')

        if not self.verify_input_cached(block.mining_blob, target):
            if enable_logging:
                logger.warning("PoW verification failed")
                qn = Qryptonight()
                tmp_hash = qn.hash(block.mining_blob)
                logger.warning("{}".format(tmp_hash))
                logger.debug('%s', block.to_json())
            return False

        return True

    @functools.lru_cache(maxsize=5)
    def verify_input_cached(self, mining_blob, target):
        return PoWHelper.verifyInput(mining_blob, target)

    def _pre_check(self, block):
        if self.state.get_block(block.headerhash):  # Duplicate block check
            logger.info('Duplicate block %s %s', block.block_number,
                        bin2hstr(block.headerhash))
            return False

        return True

    def _try_orphan_add_block(self, block, batch):
        prev_block_metadata = self.state.get_block_metadata(
            block.prev_headerhash)

        if prev_block_metadata is None or prev_block_metadata.is_orphan:
            self.state.put_block(block, batch)
            self.add_block_metadata(block.headerhash, block.timestamp,
                                    block.prev_headerhash, batch)
            return True

        return False

    def _try_branch_add_block(self, block, batch=None) -> bool:
        address_set = self.state.prepare_address_list(
            block)  # Prepare list for current block
        if self.last_block.headerhash == block.prev_headerhash:
            address_txn = self.state.get_state_mainchain(address_set)
        else:
            address_txn, rollback_headerhash, hash_path = self.state.get_state(
                block.prev_headerhash, address_set)

        if block.apply_state_changes(address_txn):
            self.state.put_block(block, None)
            self.add_block_metadata(block.headerhash, block.timestamp,
                                    block.prev_headerhash, None)

            last_block_metadata = self.state.get_block_metadata(
                self.last_block.headerhash)
            new_block_metadata = self.state.get_block_metadata(
                block.headerhash)
            last_block_difficulty = int(
                UInt256ToString(last_block_metadata.cumulative_difficulty))
            new_block_difficulty = int(
                UInt256ToString(new_block_metadata.cumulative_difficulty))

            if new_block_difficulty > last_block_difficulty:
                if self.last_block.headerhash != block.prev_headerhash:
                    self.rollback(rollback_headerhash, hash_path,
                                  block.block_number)

                self.state.put_addresses_state(address_txn)
                self.last_block = block
                self._update_mainchain(block, batch)
                self.tx_pool.remove_tx_in_block_from_pool(block)
                self.tx_pool.check_stale_txn(block.block_number)
                self.state.update_mainchain_height(block.block_number, batch)
                self.state.update_tx_metadata(block, batch)

                self.trigger_miner = True

            return True

        return False

    def remove_block_from_mainchain(self, block: Block,
                                    latest_block_number: int, batch):
        addresses_set = self.state.prepare_address_list(block)
        addresses_state = self.state.get_state_mainchain(addresses_set)
        # Revert in reverse order; index 0 (the coinbase) must be reverted too
        for tx_idx in range(len(block.transactions) - 1, -1, -1):
            tx = Transaction.from_pbdata(block.transactions[tx_idx])
            tx.revert_state_changes(addresses_state, self.state)

        self.tx_pool.add_tx_from_block_to_pool(block, latest_block_number)
        self.state.update_mainchain_height(block.block_number - 1, batch)
        self.state.rollback_tx_metadata(block, batch)
        self.state.remove_blocknumber_mapping(block.block_number, batch)
        self.state.put_addresses_state(addresses_state, batch)

    def rollback(self, rollback_headerhash, hash_path, latest_block_number):
        while self.last_block.headerhash != rollback_headerhash:
            self.remove_block_from_mainchain(self.last_block,
                                             latest_block_number, None)
            self.last_block = self.state.get_block(
                self.last_block.prev_headerhash)

        for header_hash in hash_path[-1::-1]:
            block = self.state.get_block(header_hash)
            address_set = self.state.prepare_address_list(
                block)  # Prepare list for current block
            addresses_state = self.state.get_state_mainchain(address_set)

            for tx_idx in range(0, len(block.transactions)):
                tx = Transaction.from_pbdata(block.transactions[tx_idx])
                tx.apply_state_changes(addresses_state)

            self.state.put_addresses_state(addresses_state)
            self.last_block = block
            self._update_mainchain(block, None)
            self.tx_pool.remove_tx_in_block_from_pool(block)
            self.state.update_mainchain_height(block.block_number, None)
            self.state.update_tx_metadata(block, None)

        self.trigger_miner = True

    def _add_block(self, block, batch=None):
        self.trigger_miner = False

        block_size_limit = self.state.get_block_size_limit(block)
        if block_size_limit and block.size > block_size_limit:
            logger.info('Block Size greater than threshold limit %s > %s',
                        block.size, block_size_limit)
            return False

        if self._try_orphan_add_block(block, batch):
            return True

        if self._try_branch_add_block(block, batch):
            return True

        return False

    def add_block(self, block: Block) -> bool:
        if block.block_number < self.height - config.dev.reorg_limit:
            logger.debug('Skipping block #%s as beyond re-org limit',
                         block.block_number)
            return False

        batch = self.state.get_batch()
        if self._add_block(block, batch=batch):
            self.state.write_batch(batch)
            self.update_child_metadata(
                block.headerhash
            )  # TODO: Not needed to execute when an orphan block is added
            logger.info('Added Block #%s %s', block.block_number,
                        bin2hstr(block.headerhash))
            return True

        return False

    def update_child_metadata(self, headerhash):
        block_metadata = self.state.get_block_metadata(headerhash)

        childs = list(block_metadata.child_headerhashes)

        while childs:
            child_headerhash = childs.pop(0)
            block = self.state.get_block(child_headerhash)
            if not block:
                continue
            if not self._add_block(block):
                self._prune([block.headerhash], None)
                continue
            block_metadata = self.state.get_block_metadata(child_headerhash)
            childs += block_metadata.child_headerhashes

    def _prune(self, childs, batch):
        while childs:
            child_headerhash = childs.pop(0)

            block_metadata = self.state.get_block_metadata(child_headerhash)
            childs += block_metadata.child_headerhashes

            self.state.delete(bin2hstr(child_headerhash).encode(), batch)
            self.state.delete(
                b'metadata_' + bin2hstr(child_headerhash).encode(), batch)

    def add_block_metadata(self, headerhash, block_timestamp,
                           parent_headerhash, batch):
        block_metadata = self.state.get_block_metadata(headerhash)
        if not block_metadata:
            block_metadata = BlockMetadata.create()

        parent_metadata = self.state.get_block_metadata(parent_headerhash)
        # 32 zero bytes represent a 256-bit zero
        block_difficulty = (0, ) * 32
        block_cumulative_difficulty = (0, ) * 32
        if not parent_metadata:
            parent_metadata = BlockMetadata.create()
        else:
            parent_block = self.state.get_block(parent_headerhash)
            if parent_block:
                parent_block_difficulty = parent_metadata.block_difficulty
                parent_cumulative_difficulty = parent_metadata.cumulative_difficulty

                if not parent_metadata.is_orphan:
                    block_metadata.update_last_headerhashes(
                        parent_metadata.last_N_headerhashes, parent_headerhash)
                    measurement = self.state.get_measurement(
                        block_timestamp, parent_headerhash, parent_metadata)

                    block_difficulty, _ = DifficultyTracker.get(
                        measurement=measurement,
                        parent_difficulty=parent_block_difficulty)

                    block_cumulative_difficulty = StringToUInt256(
                        str(
                            int(UInt256ToString(block_difficulty)) +
                            int(UInt256ToString(parent_cumulative_difficulty)))
                    )

        block_metadata.set_orphan(parent_metadata.is_orphan)
        block_metadata.set_block_difficulty(block_difficulty)
        block_metadata.set_cumulative_difficulty(block_cumulative_difficulty)

        parent_metadata.add_child_headerhash(headerhash)
        self.state.put_block_metadata(parent_headerhash, parent_metadata,
                                      batch)
        self.state.put_block_metadata(headerhash, block_metadata, batch)

        # Call once to populate the cache
        self.state.get_block_datapoint(headerhash)

    def _update_mainchain(self, block, batch):
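        # Walk back from `block`, rewriting the block_number -> headerhash
        # mapping until the stored mapping already matches the promoted chain.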
        block_number_mapping = None
        while block_number_mapping is None or block.headerhash != block_number_mapping.headerhash:
            block_number_mapping = qrl_pb2.BlockNumberMapping(
                headerhash=block.headerhash,
                prev_headerhash=block.prev_headerhash)
            self.state.put_block_number_mapping(block.block_number,
                                                block_number_mapping, batch)
            block = self.state.get_block(block.prev_headerhash)
            block_number_mapping = self.state.get_block_number_mapping(
                block.block_number)

    def get_block_by_number(self, block_number) -> Optional[Block]:
        return self.state.get_block_by_number(block_number)

    def get_state(self, headerhash):
        return self.state.get_state(headerhash, set())

    def get_address(self, address):
        return self.state.get_address_state(address)

    def get_transaction(self, transaction_hash) -> list:
        for tx_set in self.tx_pool.transactions:
            tx = tx_set[1].transaction
            if tx.txhash == transaction_hash:
                return [tx, None]

        return self.state.get_tx_metadata(transaction_hash)

    def get_headerhashes(self, start_blocknumber):
        start_blocknumber = max(0, start_blocknumber)
        end_blocknumber = min(self.last_block.block_number,
                              start_blocknumber + 2 * config.dev.reorg_limit)

        total_expected_headerhash = end_blocknumber - start_blocknumber + 1

        node_header_hash = qrl_pb2.NodeHeaderHash()
        node_header_hash.block_number = start_blocknumber

        block = self.state.get_block_by_number(end_blocknumber)
        block_headerhash = block.headerhash
        node_header_hash.headerhashes.append(block_headerhash)
        end_blocknumber -= 1

        while end_blocknumber >= start_blocknumber:
            block_metadata = self.state.get_block_metadata(block_headerhash)
            for headerhash in block_metadata.last_N_headerhashes[-1::-1]:
                node_header_hash.headerhashes.append(headerhash)
            end_blocknumber -= len(block_metadata.last_N_headerhashes)
            if len(block_metadata.last_N_headerhashes) == 0:
                break
            block_headerhash = block_metadata.last_N_headerhashes[0]

        node_header_hash.headerhashes[:] = node_header_hash.headerhashes[
            -1::-1]
        del node_header_hash.headerhashes[:len(node_header_hash.headerhashes) -
                                          total_expected_headerhash]

        return node_header_hash

    def add_ephemeral_message(self,
                              encrypted_ephemeral: EncryptedEphemeralMessage):
        self.state.update_ephemeral(encrypted_ephemeral)
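
A minimal usage sketch for this variant. The import paths and the State constructor are assumptions based on the qrl package these examples come from, not verified against this exact revision:

# Hypothetical wiring of ChainManager (assumed qrl import paths):
from qrl.core.State import State
from qrl.core.GenesisBlock import GenesisBlock

state = State()                      # opens the node database
chain_manager = ChainManager(state)
chain_manager.load(GenesisBlock())   # writes genesis state on an empty DB
print(chain_manager.height)          # 0 on a fresh database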
Example #2
class ChainManager:
    def __init__(self, state):
        self._state = state
        self.tx_pool = TransactionPool(None)
        self._last_block = Block.deserialize(GenesisBlock().serialize())
        self.current_difficulty = StringToUInt256(
            str(config.user.genesis_difficulty))

        self.trigger_miner = False
        self.lock = threading.RLock()

    @property
    def height(self):
        with self.lock:
            if not self._last_block:
                return -1
            return self._last_block.block_number

    @property
    def last_block(self) -> Block:
        with self.lock:
            return self._last_block

    @property
    def total_coin_supply(self):
        with self.lock:
            return self._state.total_coin_supply

    def get_block_datapoint(self, headerhash):
        with self.lock:
            return self._state.get_block_datapoint(headerhash)

    def get_cumulative_difficulty(self):
        with self.lock:
            last_block_metadata = self._state.get_block_metadata(
                self._last_block.headerhash)
            return last_block_metadata.cumulative_difficulty

    def get_block_by_number(self, block_number) -> Optional[Block]:
        with self.lock:
            return self._state.get_block_by_number(block_number)

    def get_block(self, header_hash: bytes) -> Optional[Block]:
        with self.lock:
            return self._state.get_block(header_hash)

    def get_address_balance(self, address: bytes) -> int:
        with self.lock:
            return self._state.get_address_balance(address)

    def get_address_is_used(self, address: bytes) -> bool:
        with self.lock:
            return self._state.get_address_is_used(address)

    def get_address_state(self, address: bytes) -> AddressState:
        with self.lock:
            return self._state.get_address_state(address)

    def get_all_address_state(self):
        with self.lock:
            return self._state.get_all_address_state()

    def get_tx_metadata(self, transaction_hash) -> list:
        with self.lock:
            return self._state.get_tx_metadata(transaction_hash)

    def get_last_transactions(self):
        with self.lock:
            return self._state.get_last_txs()

    def get_unconfirmed_transaction(self, transaction_hash) -> list:
        with self.lock:
            for tx_set in self.tx_pool.transactions:
                tx = tx_set[1].transaction
                if tx.txhash == transaction_hash:
                    return [tx, tx_set[1].timestamp]
            return []

    def get_block_metadata(self,
                           header_hash: bytes) -> Optional[BlockMetadata]:
        with self.lock:
            return self._state.get_block_metadata(header_hash)

    def get_blockheader_and_metadata(self, block_number=0) -> Tuple:
        with self.lock:
            block_number = block_number or self.height  # if both are non-zero, then block_number takes priority

            result = (None, None)
            block = self.get_block_by_number(block_number)
            if block:
                blockheader = block.blockheader
                blockmetadata = self.get_block_metadata(blockheader.headerhash)
                result = (blockheader, blockmetadata)

            return result

    def get_block_to_mine(self, miner, wallet_address) -> list:
        with miner.lock:  # Trying to acquire miner.lock to make sure pre_block_logic is not running
            with self.lock:
                last_block = self.last_block
                last_block_metadata = self.get_block_metadata(
                    last_block.headerhash)
                return miner.get_block_to_mine(
                    wallet_address, self.tx_pool, last_block,
                    last_block_metadata.block_difficulty)

    def get_measurement(self, block_timestamp, parent_headerhash,
                        parent_metadata: BlockMetadata):
        with self.lock:
            return self._state.get_measurement(block_timestamp,
                                               parent_headerhash,
                                               parent_metadata)

    def get_block_size_limit(self, block: Block):
        with self.lock:
            return self._state.get_block_size_limit(block)

    def get_block_is_duplicate(self, block: Block) -> bool:
        with self.lock:
            return self._state.get_block(block.headerhash) is not None

    def validate_mining_nonce(self,
                              blockheader: BlockHeader,
                              enable_logging=True):
        with self.lock:
            parent_metadata = self.get_block_metadata(
                blockheader.prev_headerhash)
            parent_block = self._state.get_block(blockheader.prev_headerhash)

            measurement = self.get_measurement(blockheader.timestamp,
                                               blockheader.prev_headerhash,
                                               parent_metadata)
            diff, target = DifficultyTracker.get(
                measurement=measurement,
                parent_difficulty=parent_metadata.block_difficulty)

            if enable_logging:
                logger.debug('-----------------START--------------------')
                logger.debug('Validate                #%s',
                             blockheader.block_number)
                logger.debug('block.timestamp         %s',
                             blockheader.timestamp)
                logger.debug('parent_block.timestamp  %s',
                             parent_block.timestamp)
                logger.debug('parent_block.difficulty %s',
                             UInt256ToString(parent_metadata.block_difficulty))
                logger.debug('diff                    %s',
                             UInt256ToString(diff))
                logger.debug('target                  %s', bin2hstr(target))
                logger.debug('-------------------END--------------------')

            if not PoWValidator().verify_input(blockheader.mining_blob,
                                               target):
                if enable_logging:
                    logger.warning("PoW verification failed")
                    qn = Qryptonight()
                    tmp_hash = qn.hash(blockheader.mining_blob)
                    logger.warning("{}".format(bin2hstr(tmp_hash)))
                    logger.debug('%s', blockheader.to_json())
                return False

            return True

    def get_headerhashes(self, start_blocknumber):
        with self.lock:
            start_blocknumber = max(0, start_blocknumber)
            end_blocknumber = min(
                self._last_block.block_number,
                start_blocknumber + 2 * config.dev.reorg_limit)

            total_expected_headerhash = end_blocknumber - start_blocknumber + 1

            node_header_hash = qrl_pb2.NodeHeaderHash()
            node_header_hash.block_number = start_blocknumber

            block = self._state.get_block_by_number(end_blocknumber)
            block_headerhash = block.headerhash
            node_header_hash.headerhashes.append(block_headerhash)
            end_blocknumber -= 1

            while end_blocknumber >= start_blocknumber:
                block_metadata = self._state.get_block_metadata(
                    block_headerhash)
                for headerhash in block_metadata.last_N_headerhashes[-1::-1]:
                    node_header_hash.headerhashes.append(headerhash)
                end_blocknumber -= len(block_metadata.last_N_headerhashes)
                if len(block_metadata.last_N_headerhashes) == 0:
                    break
                block_headerhash = block_metadata.last_N_headerhashes[0]

            node_header_hash.headerhashes[:] = node_header_hash.headerhashes[
                -1::-1]
            del node_header_hash.headerhashes[:len(node_header_hash.
                                                   headerhashes) -
                                              total_expected_headerhash]

            return node_header_hash

    def set_broadcast_tx(self, broadcast_tx):
        with self.lock:
            self.tx_pool.set_broadcast_tx(broadcast_tx)

    def load(self, genesis_block):
        # load() has the following tasks:
        # - Write the Genesis Block into State immediately
        # - Register the block_number <-> blockhash mapping
        # - Calculate difficulty metadata for the Genesis Block
        # - Generate AddressStates from the Genesis Block balances
        # - Apply the Genesis Block's transactions to the state
        # - Detect an interrupted fork recovery and, if one is found, resume it
        height = self._state.get_mainchain_height()

        if height == -1:
            self._state.put_block(genesis_block, None)
            block_number_mapping = qrl_pb2.BlockNumberMapping(
                headerhash=genesis_block.headerhash,
                prev_headerhash=genesis_block.prev_headerhash)

            self._state.put_block_number_mapping(genesis_block.block_number,
                                                 block_number_mapping, None)
            parent_difficulty = StringToUInt256(
                str(config.user.genesis_difficulty))

            self.current_difficulty, _ = DifficultyTracker.get(
                measurement=config.dev.mining_setpoint_blocktime,
                parent_difficulty=parent_difficulty)

            block_metadata = BlockMetadata.create()
            block_metadata.set_block_difficulty(self.current_difficulty)
            block_metadata.set_cumulative_difficulty(self.current_difficulty)

            self._state.put_block_metadata(genesis_block.headerhash,
                                           block_metadata, None)
            addresses_state = dict()
            for genesis_balance in GenesisBlock().genesis_balance:
                bytes_addr = genesis_balance.address
                addresses_state[bytes_addr] = AddressState.get_default(
                    bytes_addr)
                addresses_state[
                    bytes_addr]._data.balance = genesis_balance.balance

            for tx_idx in range(1, len(genesis_block.transactions)):
                tx = Transaction.from_pbdata(
                    genesis_block.transactions[tx_idx])
                for addr in tx.addrs_to:
                    addresses_state[addr] = AddressState.get_default(addr)

            coinbase_tx = Transaction.from_pbdata(
                genesis_block.transactions[0])

            if not isinstance(coinbase_tx, CoinBase):
                return False

            addresses_state[coinbase_tx.addr_to] = AddressState.get_default(
                coinbase_tx.addr_to)

            if not coinbase_tx.validate_extended(genesis_block.block_number):
                return False

            coinbase_tx.apply_state_changes(addresses_state)

            for tx_idx in range(1, len(genesis_block.transactions)):
                tx = Transaction.from_pbdata(
                    genesis_block.transactions[tx_idx])
                tx.apply_state_changes(addresses_state)

            self._state.put_addresses_state(addresses_state)
            self._state.update_tx_metadata(genesis_block, None)
            self._state.update_mainchain_height(0, None)
        else:
            self._last_block = self.get_block_by_number(height)
            self.current_difficulty = self._state.get_block_metadata(
                self._last_block.headerhash).block_difficulty
            fork_state = self._state.get_fork_state()
            if fork_state:
                block = self._state.get_block(fork_state.initiator_headerhash)
                self._fork_recovery(block, fork_state)

    def _apply_block(self, block: Block, batch) -> bool:
        address_set = self._state.prepare_address_list(
            block)  # Prepare list for current block
        addresses_state = self._state.get_state_mainchain(address_set)
        if not block.apply_state_changes(addresses_state):
            return False
        self._state.put_addresses_state(addresses_state, batch)
        return True

    def _update_chainstate(self, block: Block, batch):
        self._last_block = block
        self._update_block_number_mapping(block, batch)
        self.tx_pool.remove_tx_in_block_from_pool(block)
        self._state.update_mainchain_height(block.block_number, batch)
        self._state.update_tx_metadata(block, batch)

    def _try_branch_add_block(self,
                              block,
                              batch,
                              check_stale=True) -> Tuple[bool, bool]:
        """
        Returns a tuple of two bools. The first indicates whether the
        block was added successfully; the second is the fork flag, which
        becomes True when the block triggered fork recovery.
        :param block:
        :param batch:
        :return: (block_added, fork_flag)
        """
        if self._last_block.headerhash == block.prev_headerhash:
            if not self._apply_block(block, batch):
                return False, False

        self._state.put_block(block, batch)

        last_block_metadata = self._state.get_block_metadata(
            self._last_block.headerhash)
        if last_block_metadata is None:
            logger.warning("Could not find log metadata for %s",
                           bin2hstr(self._last_block.headerhash))
            return False, False

        last_block_difficulty = int(
            UInt256ToString(last_block_metadata.cumulative_difficulty))

        new_block_metadata = self._add_block_metadata(block.headerhash,
                                                      block.timestamp,
                                                      block.prev_headerhash,
                                                      batch)
        new_block_difficulty = int(
            UInt256ToString(new_block_metadata.cumulative_difficulty))

        if new_block_difficulty > last_block_difficulty:
            if self._last_block.headerhash != block.prev_headerhash:
                fork_state = qrlstateinfo_pb2.ForkState(
                    initiator_headerhash=block.headerhash)
                self._state.put_fork_state(fork_state, batch)
                self._state.write_batch(batch)
                return self._fork_recovery(block, fork_state), True

            self._update_chainstate(block, batch)
            if check_stale:
                self.tx_pool.check_stale_txn(self._state, block.block_number)
            self.trigger_miner = True

        return True, False

    def _remove_block_from_mainchain(self, block: Block,
                                     latest_block_number: int, batch):
        addresses_set = self._state.prepare_address_list(block)
        addresses_state = self._state.get_state_mainchain(addresses_set)
        for tx_idx in range(len(block.transactions) - 1, -1, -1):
            tx = Transaction.from_pbdata(block.transactions[tx_idx])
            tx.revert_state_changes(addresses_state, self)

        self.tx_pool.add_tx_from_block_to_pool(block, latest_block_number)
        self._state.update_mainchain_height(block.block_number - 1, batch)
        self._state.rollback_tx_metadata(block, batch)
        self._state.remove_blocknumber_mapping(block.block_number, batch)
        self._state.put_addresses_state(addresses_state, batch)

    def _get_fork_point(self, block: Block):
        tmp_block = block
        hash_path = []
        while True:
            if not block:
                raise Exception(
                    '[_get_fork_point] No block found, initiator %s' %
                    bin2hstr(tmp_block.headerhash))
            mainchain_block = self.get_block_by_number(block.block_number)
            if mainchain_block and mainchain_block.headerhash == block.headerhash:
                break
            if block.block_number == 0:
                raise Exception(
                    '[_get_fork_point] Alternate chain genesis is different, '
                    'initiator %s' % bin2hstr(tmp_block.headerhash))
            hash_path.append(block.headerhash)
            block = self._state.get_block(block.prev_headerhash)

        return block.headerhash, hash_path

    def _rollback(self,
                  forked_header_hash: bytes,
                  fork_state: qrlstateinfo_pb2.ForkState = None):
        """
        Rollback from last block to the block just before the forked_header_hash
        :param forked_header_hash:
        :param fork_state:
        :return:
        """
        hash_path = []
        while self._last_block.headerhash != forked_header_hash:
            block = self._state.get_block(self._last_block.headerhash)
            if block is None:
                logger.warning(
                    "self._state.get_block(self._last_block.headerhash) returned None"
                )
                break

            mainchain_block = self._state.get_block_by_number(
                block.block_number)
            if mainchain_block is None:
                logger.warning(
                    "self._state.get_block_by_number(block.block_number) returned None"
                )
                break

            if block.headerhash != mainchain_block.headerhash:
                break
            hash_path.append(self._last_block.headerhash)

            batch = self._state.batch
            self._remove_block_from_mainchain(self._last_block,
                                              block.block_number, batch)

            if fork_state:
                fork_state.old_mainchain_hash_path.extend(
                    [self._last_block.headerhash])
                self._state.put_fork_state(fork_state, batch)

            self._state.write_batch(batch)

            self._last_block = self._state.get_block(
                self._last_block.prev_headerhash)

        return hash_path

    def add_chain(self, hash_path: list,
                  fork_state: qrlstateinfo_pb2.ForkState) -> bool:
        """
        Add series of blocks whose headerhash mentioned into hash_path
        :param hash_path:
        :param fork_state:
        :param batch:
        :return:
        """
        with self.lock:
            start = 0
            try:
                start = hash_path.index(self._last_block.headerhash) + 1
            except ValueError:
                # Following condition can only be true if the fork recovery was interrupted last time
                if self._last_block.headerhash in fork_state.old_mainchain_hash_path:
                    return False

            for i in range(start, len(hash_path)):
                header_hash = hash_path[i]
                block = self._state.get_block(header_hash)

                batch = self._state.batch

                if not self._apply_block(block, batch):
                    return False

                self._update_chainstate(block, batch)

                logger.debug('Apply block #%d - [batch %d | %s]',
                             block.block_number, i, hash_path[i])
                self._state.write_batch(batch)

            self._state.delete_fork_state()

            return True

    def _fork_recovery(self, block: Block,
                       fork_state: qrlstateinfo_pb2.ForkState) -> bool:
        logger.info("Triggered Fork Recovery")
        # This condition only becomes true, when fork recovery was interrupted
        if fork_state.fork_point_headerhash:
            logger.info("Recovering from last fork recovery interruption")
            forked_header_hash, hash_path = fork_state.fork_point_headerhash, fork_state.new_mainchain_hash_path
        else:
            forked_header_hash, hash_path = self._get_fork_point(block)
            fork_state.fork_point_headerhash = forked_header_hash
            fork_state.new_mainchain_hash_path.extend(hash_path)
            self._state.put_fork_state(fork_state)

        rollback_done = False
        if fork_state.old_mainchain_hash_path:
            b = self._state.get_block(fork_state.old_mainchain_hash_path[-1])
            if b and b.prev_headerhash == fork_state.fork_point_headerhash:
                rollback_done = True

        if not rollback_done:
            logger.info("Rolling back")
            old_hash_path = self._rollback(forked_header_hash, fork_state)
        else:
            old_hash_path = fork_state.old_mainchain_hash_path

        if not self.add_chain(hash_path[-1::-1], fork_state):
            logger.warning(
                "Fork Recovery Failed... Recovering back to old mainchain")
            # If above condition is true, then it means, the node failed to add_chain
            # Thus old chain state, must be retrieved
            self._rollback(forked_header_hash)
            self.add_chain(old_hash_path[-1::-1],
                           fork_state)  # Restores the old chain state
            return False

        logger.info("Fork Recovery Finished")

        self.trigger_miner = True
        return True

    def _add_block(self, block, batch=None,
                   check_stale=True) -> Tuple[bool, bool]:
        self.trigger_miner = False

        block_size_limit = self.get_block_size_limit(block)
        if block_size_limit and block.size > block_size_limit:
            logger.info('Block Size greater than threshold limit %s > %s',
                        block.size, block_size_limit)
            return False, False

        return self._try_branch_add_block(block, batch, check_stale)

    def add_block(self, block: Block, check_stale=True) -> bool:
        with self.lock:
            if block.block_number < self.height - config.dev.reorg_limit:
                logger.debug('Skipping block #%s as beyond re-org limit',
                             block.block_number)
                return False

            if self.get_block_is_duplicate(block):
                return False

            batch = self._state.batch
            block_flag, fork_flag = self._add_block(block,
                                                    batch=batch,
                                                    check_stale=check_stale)
            if block_flag:
                if not fork_flag:
                    self._state.write_batch(batch)
                logger.info('Added Block #%s %s', block.block_number,
                            bin2hstr(block.headerhash))
                return True

            return False

    def _add_block_metadata(self, headerhash, block_timestamp,
                            parent_headerhash, batch):
        block_metadata = self._state.get_block_metadata(headerhash)
        if not block_metadata:
            block_metadata = BlockMetadata.create()

        parent_metadata = self._state.get_block_metadata(parent_headerhash)

        parent_block_difficulty = parent_metadata.block_difficulty
        parent_cumulative_difficulty = parent_metadata.cumulative_difficulty

        block_metadata.update_last_headerhashes(
            parent_metadata.last_N_headerhashes, parent_headerhash)
        measurement = self._state.get_measurement(block_timestamp,
                                                  parent_headerhash,
                                                  parent_metadata)

        block_difficulty, _ = DifficultyTracker.get(
            measurement=measurement, parent_difficulty=parent_block_difficulty)

        block_cumulative_difficulty = StringToUInt256(
            str(
                int(UInt256ToString(block_difficulty)) +
                int(UInt256ToString(parent_cumulative_difficulty))))

        block_metadata.set_block_difficulty(block_difficulty)
        block_metadata.set_cumulative_difficulty(block_cumulative_difficulty)

        parent_metadata.add_child_headerhash(headerhash)
        self._state.put_block_metadata(parent_headerhash, parent_metadata,
                                       batch)
        self._state.put_block_metadata(headerhash, block_metadata, batch)

        return block_metadata

    def _update_block_number_mapping(self, block, batch):
        block_number_mapping = qrl_pb2.BlockNumberMapping(
            headerhash=block.headerhash, prev_headerhash=block.prev_headerhash)
        self._state.put_block_number_mapping(block.block_number,
                                             block_number_mapping, batch)
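
Both variants select the new chain tip by cumulative difficulty rather than height: _try_branch_add_block only promotes a block to the mainchain (or triggers fork recovery) when its chain is strictly heavier. In isolation, the comparison reduces to the sketch below; is_heavier is an illustrative helper, not part of the class:

def is_heavier(new_cumulative_diff, last_cumulative_diff) -> bool:
    # Difficulties are 256-bit values stored as byte tuples, so they are
    # converted through UInt256ToString and compared as Python ints.
    return (int(UInt256ToString(new_cumulative_diff)) >
            int(UInt256ToString(last_cumulative_diff)))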
class TestTransactionPool(TestCase):
    """
    TransactionPool sits between incoming Transactions from the network and Blocks.
    First, incoming Transactions are pending Transactions and go into TransactionPool.pending_tx_pool.
    The TxnProcessor has to validate them. Once they are validated, the TxnProcessor puts them into
    TransactionPool.transaction_pool, where they wait to be put into the next mined Block.
    """
    def setUp(self):
        self.txpool = TransactionPool(None)

    def test_add_tx_to_pool(self):
        tx = make_tx()
        result = self.txpool.add_tx_to_pool(tx, 1, replacement_getTime())
        self.assertTrue(result)
        self.assertEqual(len(self.txpool.transactions), 1)

    @patch('qrl.core.TransactionPool.TransactionPool.is_full_transaction_pool',
           autospec=True)
    def test_add_tx_to_pool_while_full(self, m_is_full_func):
        m_is_full_func.return_value = True
        tx = make_tx()
        result = self.txpool.add_tx_to_pool(tx, 1, replacement_getTime())
        self.assertFalse(result)  # refused to add to the pool
        self.assertEqual(len(self.txpool.transactions), 0)  # remains untouched

    @patch('qrl.core.TransactionPool.config', autospec=True)
    def test_is_full_transaction_pool(self, m_config):
        m_config.user.transaction_pool_size = 2

        result = self.txpool.is_full_transaction_pool()
        self.assertFalse(result)

        tx1 = make_tx(fee=1)
        tx2 = make_tx(fee=2)

        self.txpool.add_tx_to_pool(tx1, 1, replacement_getTime())
        self.txpool.add_tx_to_pool(tx2, 1, replacement_getTime())

        result = self.txpool.is_full_transaction_pool()
        self.assertTrue(result)

    def test_get_tx_index_from_pool(self):
        tx1 = make_tx(txhash=b'red')
        tx2 = make_tx(txhash=b'blue')
        tx3 = make_tx(txhash=b'qrlpink')

        self.txpool.add_tx_to_pool(tx1, 1, replacement_getTime())
        self.txpool.add_tx_to_pool(tx2, 1, replacement_getTime())
        self.txpool.add_tx_to_pool(tx3, 1, replacement_getTime())

        idx = self.txpool.get_tx_index_from_pool(b'qrlpink')
        self.assertEqual(idx, 2)

        idx = self.txpool.get_tx_index_from_pool(b'red')
        self.assertEqual(idx, 0)

        idx = self.txpool.get_tx_index_from_pool(b'ultraviolet')
        self.assertEqual(idx, -1)

    def test_remove_tx_from_pool(self):
        tx1 = make_tx(txhash=b'red')
        tx2 = make_tx(txhash=b'blue')
        tx3 = make_tx(txhash=b'qrlpink')

        self.txpool.add_tx_to_pool(tx1, 1, replacement_getTime())

        # If we try to remove a tx that wasn't there, the transaction pool should be untouched
        self.assertEqual(len(self.txpool.transaction_pool), 1)
        self.txpool.remove_tx_from_pool(tx2)
        self.assertEqual(len(self.txpool.transaction_pool), 1)

        # Now let's remove a tx from the heap. The size should decrease.
        self.txpool.add_tx_to_pool(tx2, 1, replacement_getTime())
        self.txpool.add_tx_to_pool(tx3, 1, replacement_getTime())

        self.assertEqual(len(self.txpool.transaction_pool), 3)
        self.txpool.remove_tx_from_pool(tx2)
        self.assertEqual(len(self.txpool.transaction_pool), 2)

    @patch(
        'qrl.core.TransactionPool.TransactionPool.is_full_pending_transaction_pool',
        autospec=True)
    def test_update_pending_tx_pool(self, m_is_full_pending_transaction_pool):
        tx1 = make_tx()
        ip = '127.0.0.1'
        m_is_full_pending_transaction_pool.return_value = False

        # Due to the straightforward way the function is written, no special setup is needed to get the tx to go in.
        result = self.txpool.update_pending_tx_pool(tx1, ip)
        self.assertTrue(result)

        # If we try to re-add the same tx to the pending_tx_pool, though, it should fail.
        result = self.txpool.update_pending_tx_pool(tx1, ip)
        self.assertFalse(result)

    @patch(
        'qrl.core.TransactionPool.TransactionPool.is_full_pending_transaction_pool',
        autospec=True)
    def test_update_pending_tx_pool_tx_already_validated(
            self, m_is_full_pending_transaction_pool):
        """
        If the tx is already in TransactionPool.transaction_pool, do not add it to pending_tx_pool.
        """
        tx1 = make_tx()
        ip = '127.0.0.1'
        m_is_full_pending_transaction_pool.return_value = False

        self.txpool.add_tx_to_pool(tx1, 1, replacement_getTime())

        result = self.txpool.update_pending_tx_pool(tx1, ip)
        self.assertFalse(result)

    @patch(
        'qrl.core.TransactionPool.TransactionPool.is_full_pending_transaction_pool',
        autospec=True)
    def test_update_pending_tx_pool_is_full_already(
            self, m_is_full_pending_transaction_pool):
        tx1 = make_tx()
        ip = '127.0.0.1'
        m_is_full_pending_transaction_pool.return_value = True

        result = self.txpool.update_pending_tx_pool(tx1, ip)
        self.assertFalse(result)

    @patch('qrl.core.TransactionPool.logger')
    @patch(
        'qrl.core.TransactionPool.TransactionPool.is_full_pending_transaction_pool',
        autospec=True)
    def test_update_pending_tx_pool_rejects_coinbase_txs(
            self, m_is_full_pending_transaction_pool, m_logger):
        tx1 = CoinBase()
        ip = '127.0.0.1'
        m_is_full_pending_transaction_pool.return_value = False

        result = self.txpool.update_pending_tx_pool(tx1, ip)
        self.assertFalse(result)

    @patch('qrl.core.TransactionPool.config', autospec=True)
    def test_is_full_pending_transaction_pool(self, m_config):
        """
        pending_transaction_pool_size is 3; pending_transaction_pool_reserve (1)
        is subtracted from it, leaving an effective capacity of 2. Adding a third
        transaction with ignore_reserve=True therefore fails, but with
        ignore_reserve=False it goes in. After that, the pool is completely full
        and any further transaction is rejected either way.
        """
        m_config.user.pending_transaction_pool_size = 3
        m_config.user.pending_transaction_pool_reserve = 1

        tx4 = make_tx(txhash=b'red')
        tx1 = make_tx(txhash=b'green')
        tx3 = make_tx(txhash=b'blue')
        tx2 = make_tx(txhash=b'pink')
        ip = '127.0.0.1'

        self.txpool.update_pending_tx_pool(tx1, ip)
        self.txpool.update_pending_tx_pool(tx2, ip)
        result = self.txpool.update_pending_tx_pool(tx3,
                                                    ip,
                                                    ignore_reserve=True)
        self.assertFalse(result)
        result = self.txpool.update_pending_tx_pool(tx3,
                                                    ip,
                                                    ignore_reserve=False)
        self.assertTrue(result)

        result = self.txpool.update_pending_tx_pool(tx4,
                                                    ip,
                                                    ignore_reserve=True)
        self.assertFalse(result)
        result = self.txpool.update_pending_tx_pool(tx4,
                                                    ip,
                                                    ignore_reserve=False)
        self.assertFalse(result)

    @patch('qrl.core.misc.ntp.getTime', new=replacement_getTime)
    def test_get_pending_transaction(self):
        """
        Getting a pending transaction also removes it from the TransactionPool.
        Because it may return None or a (tx, timestamp) pair, TxnProcessor stores
        the return value in a single variable and unpacks it only after checking
        that it is not None.
        """
        tx1 = make_tx()
        ip = '127.0.0.1'
        self.txpool.update_pending_tx_pool(tx1, ip)

        self.assertEqual(len(self.txpool.pending_tx_pool_hash), 1)
        tx_timestamp = self.txpool.get_pending_transaction()
        self.assertEqual(tx_timestamp[0], tx1)
        self.assertEqual(len(self.txpool.pending_tx_pool_hash), 0)

        tx_timestamp = self.txpool.get_pending_transaction()
        self.assertIsNone(tx_timestamp)

    @patch('qrl.core.TransactionPool.logger')
    @patch('qrl.core.txs.Transaction.Transaction.from_pbdata',
           return_value=make_tx())
    @patch('qrl.core.TransactionPool.TransactionPool.add_tx_to_pool',
           return_value=True)
    def test_add_tx_from_block_to_pool(self, m_add_tx_to_pool, m_from_pbdata,
                                       m_logger):
        m_block = Mock(autospec=Block,
                       block_number=5,
                       headerhash=b'test block header')
        m_block.transactions = [CoinBase(), make_tx(), make_tx()]

        self.txpool.add_tx_from_block_to_pool(m_block, 5)

        self.assertEqual(m_add_tx_to_pool.call_count,
                         2)  # 2 because the function ignores the Coinbase tx

        # If there is a problem adding to the tx_pool, the logger should be invoked.
        m_add_tx_to_pool.return_value = False
        self.txpool.add_tx_from_block_to_pool(m_block, 5)
        m_logger.warning.assert_called()

    @patch('qrl.core.txs.Transaction.Transaction.from_pbdata',
           new=replacement_from_pbdata)
    def test_remove_tx_in_block_from_pool(self):
        m_block = Mock(autospec=Block)
        tx1 = make_tx(name='Mock TX 1', ots_key=1, PK=b'pk')
        tx2 = make_tx(name='Mock TX 2', ots_key=2, PK=b'pk')
        m_block.transactions = [CoinBase(), tx1, tx2]

        # To remove the tx from the pool we have to add it first!
        self.txpool.add_tx_to_pool(tx1, 5)
        self.txpool.add_tx_to_pool(tx2, 5)
        self.assertEqual(len(self.txpool.transaction_pool), 2)

        self.txpool.remove_tx_in_block_from_pool(m_block)
        self.assertEqual(len(self.txpool.transaction_pool), 0)

    @patch('qrl.core.TransactionInfo.config', autospec=True)
    @patch('qrl.core.TransactionPool.TransactionPool.is_full_transaction_pool',
           return_value=False)
    def test_check_stale_txn(self, m_is_full_transaction_pool, m_config):
        """
        Stale Transactions are Transactions that were supposed to go into block 5, but for some reason didn't make it.
        They languish in TransactionPool until check_stale_txn() checks the Pool and updates the tx_info to make them
        go into a higher block.
        For each stale transaction, P2PFactory.broadcast_tx() will be called.
        """

        # Redefine the point at which txs are considered stale
        m_config.user.stale_transaction_threshold = 2
        bob_xmss = get_bob_xmss(4)
        alice_xmss = get_alice_xmss(4)

        tx1 = TransferTransaction.create(addrs_to=[bob_xmss.address],
                                         amounts=[1000000],
                                         fee=1,
                                         xmss_pk=alice_xmss.pk)
        tx1.sign(alice_xmss)
        tx2 = TransferTransaction.create(addrs_to=[bob_xmss.address],
                                         amounts=[10000],
                                         fee=1,
                                         xmss_pk=alice_xmss.pk)
        tx2.sign(alice_xmss)
        m_broadcast_tx = Mock(
            name='Mock Broadcast TX function (in P2PFactory)')
        self.txpool.add_tx_to_pool(tx1, 5)
        self.txpool.add_tx_to_pool(tx2, 5)
        self.txpool.set_broadcast_tx(m_broadcast_tx)

        with set_qrl_dir('no_data'):
            state = State()
            self.txpool.check_stale_txn(state, 8)

            self.assertEqual(m_broadcast_tx.call_count, 0)

            m = MockFunction()
            bob_address_state = AddressState.get_default(bob_xmss.address)
            bob_address_state.pbdata.balance = 1000000000000
            m.put(bob_xmss.address, bob_address_state)
            state.get_address_state = m.get
            tx3 = TransferTransaction.create(addrs_to=[alice_xmss.address],
                                             amounts=[10000],
                                             fee=1,
                                             xmss_pk=bob_xmss.pk)
            tx3.sign(bob_xmss)
            self.txpool.add_tx_to_pool(tx3, 5)
            self.txpool.check_stale_txn(state, 8)

            self.assertEqual(m_broadcast_tx.call_count, 1)
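
The tests above rely on a make_tx helper from the project's test utilities (as do replacement_getTime, get_bob_xmss and friends). A minimal Mock-based stand-in, consistent with how make_tx is called here but hypothetical in its exact signature, would be:

from unittest.mock import Mock

def make_tx(name='Mock TX', fee=1, txhash=b'txhash', **kwargs):
    # The tests only need fee, txhash and arbitrary extra attributes,
    # so a configured Mock can stand in for a TransferTransaction.
    return Mock(name=name, fee=fee, txhash=txhash, **kwargs)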