예제 #1
0
class QRLNode:
    def __init__(self, db_state: State, mining_address: bytes):
        """Create the node's managers and schedule the first chain-state check.

        :param db_state: state backend kept on ``self.db_state``.
        :param mining_address: address kept on ``self.mining_address`` and
            handed to ``POW`` by ``start_pow``.
        """
        self.start_time = time.time()
        self.db_state = db_state
        self.mining_address = mining_address

        # Collaborators that are injected later (see set_chain_manager,
        # start_listening and start_pow).
        self._chain_manager = None  # FIXME: REMOVE. This is temporary
        self._p2pfactory = None  # FIXME: REMOVE. This is temporary
        self._pow = None

        self._sync_state = SyncState()

        # Load the stored peers and get called back on the NO_PEERS event.
        self.peer_manager = P2PPeerManager()
        peer_manager = self.peer_manager
        peer_manager.load_peer_addresses()
        peer_manager.register(P2PPeerManager.EventType.NO_PEERS,
                              self.connect_peers)

        self.p2pchain_manager = P2PChainManager()
        self.tx_manager = P2PTxManagement()

        # Banned peers are kept in a file inside the wallet directory and
        # expire after config.user.ban_minutes.
        ban_file = os.path.join(config.user.wallet_dir,
                                config.dev.banned_peers_filename)
        ban_seconds = config.user.ban_minutes * 60
        self._banned_peers = ExpiringSet(expiration_time=ban_seconds,
                                         filename=ban_file)

        # First monitoring pass runs 10 seconds from now.
        reactor.callLater(10, self.monitor_chain_state)

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    @property
    def version(self):
        """Return the version configured in ``config.dev.version``."""
        # FIXME: Move to __version__ coming from pip
        dev_settings = config.dev
        return dev_settings.version

    @property
    def sync_state(self) -> SyncState:
        """Return the ``SyncState`` instance created in ``__init__``."""
        current_sync_state = self._sync_state
        return current_sync_state

    @property
    def state(self):
        if self._p2pfactory is None:
            return ESyncState.unknown.value
        # FIXME
        return self._p2pfactory.sync_state.state.value

    @property
    def num_connections(self):
        if self._p2pfactory is None:
            return 0
        # FIXME
        return self._p2pfactory.connections

    @property
    def num_known_peers(self):
        # FIXME
        return len(self.peer_addresses)

    @property
    def uptime(self):
        return int(time.time() - self.start_time)

    @property
    def block_height(self):
        return self._chain_manager.height

    @property
    def epoch(self):
        if not self._chain_manager.get_last_block():
            return 0
        return self._chain_manager.get_last_block(
        ).block_number // config.dev.blocks_per_epoch

    @property
    def uptime_network(self):
        block_one = self._chain_manager.get_block_by_number(1)
        network_uptime = 0
        if block_one:
            network_uptime = int(time.time() - block_one.timestamp)
        return network_uptime

    @property
    def block_last_reward(self):
        if not self._chain_manager.get_last_block():
            return 0

        return self._chain_manager.get_last_block().block_reward

    @property
    def block_time_mean(self):
        block = self._chain_manager.get_last_block()

        prev_block_metadata = self._chain_manager.state.get_block_metadata(
            block.prev_headerhash)
        if prev_block_metadata is None:
            return config.dev.mining_setpoint_blocktime

        movavg = self._chain_manager.state.get_measurement(
            block.timestamp, block.prev_headerhash, prev_block_metadata)
        return movavg

    @property
    def block_time_sd(self):
        """Return 0; no deviation is computed yet (placeholder)."""
        # FIXME: Keep a moving var
        placeholder_sd = 0
        return placeholder_sd

    @property
    def coin_supply(self):
        # FIXME: Keep a moving var
        return self.db_state.total_coin_supply()

    @property
    def coin_supply_max(self):
        """Return the value configured in ``config.dev.max_coin_supply``."""
        # FIXME: Keep a moving var
        dev_settings = config.dev
        return dev_settings.max_coin_supply

    @property
    def peer_addresses(self):
        return self.peer_manager._peer_addresses

    ####################################################
    ####################################################
    ####################################################
    ####################################################
    def is_banned(self, addr_remote: str):
        """Return whether ``addr_remote`` is in the banned-peers set."""
        banned = self._banned_peers
        return addr_remote in banned

    def ban_peer(self, peer_obj):
        """Add the peer's remote address to the ban set and disconnect it."""
        address = peer_obj.addr_remote
        self._banned_peers.add(address)
        logger.warning('Banned %s', address)
        peer_obj.loseConnection()

    def connect_peers(self):
        """Ask the P2P factory to connect to every known, non-banned peer.

        ``__init__`` registers this method as the handler of the peer
        manager's ``NO_PEERS`` event, while the factory is only created by
        ``start_listening``. If the event arrives before that, nothing is
        attempted instead of failing on a ``None`` factory.
        """
        if self._p2pfactory is None:
            logger.debug('Skipping peer reconnection: P2P factory not started')
            return

        logger.info('<<<Reconnecting to peer list: %s', self.peer_addresses)
        for peer_address in self.peer_addresses:
            if self.is_banned(peer_address):
                continue
            self._p2pfactory.connect_peer(peer_address)

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    def monitor_chain_state(self):
        """Run one chain-state monitoring pass and schedule the next one.

        A pass lets the peer manager check its channels, broadcasts this
        node's chain state (last block number, header hash, cumulative
        difficulty, current time) and, if the peer manager reports a channel
        with a better cumulative difficulty, requests its header-hash list
        starting from the local height.

        The next pass is scheduled after
        ``config.user.chain_state_broadcast_period`` even when this pass
        raises; the exception itself still propagates to the caller.
        """
        try:
            self.peer_manager.monitor_chain_state()

            last_block = self._chain_manager.get_last_block()
            block_metadata = self.db_state.get_block_metadata(
                last_block.headerhash)
            node_chain_state = qrl_pb2.NodeChainState(
                block_number=last_block.block_number,
                header_hash=last_block.headerhash,
                cumulative_difficulty=bytes(
                    block_metadata.cumulative_difficulty),
                timestamp=int(time.time()))

            self.peer_manager.broadcast_chain_state(
                node_chain_state=node_chain_state)
            channel = self.peer_manager.get_better_difficulty(
                block_metadata.cumulative_difficulty)
            logger.debug('Got better difficulty %s', channel)
            if channel:
                logger.debug('Connection id >> %s', channel.addr_remote)
                channel.get_headerhash_list(self._chain_manager.height)
        finally:
            # Re-arm unconditionally: otherwise a single failed pass would
            # silently end the periodic monitoring for the process lifetime.
            reactor.callLater(config.user.chain_state_broadcast_period,
                              self.monitor_chain_state)

    # FIXME: REMOVE. This is temporary
    def set_chain_manager(self, chain_manager: ChainManager):
        """Store ``chain_manager`` as the node's chain manager.

        :param chain_manager: object kept on ``self._chain_manager``.
        """
        self._chain_manager = chain_manager
        return None

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    def start_pow(self, mining_thread_count):
        """Create the ``POW`` object from the node's collaborators and start it.

        :param mining_thread_count: forwarded to ``POW`` unchanged.
        """
        pow_settings = dict(chain_manager=self._chain_manager,
                            p2p_factory=self._p2pfactory,
                            sync_state=self._sync_state,
                            time_provider=ntp,
                            mining_address=self.mining_address,
                            mining_thread_count=mining_thread_count)
        self._pow = POW(**pow_settings)
        self._pow.start()

    def start_listening(self):
        """Create the ``P2PFactory`` for this node and start listening."""
        factory = P2PFactory(
            chain_manager=self._chain_manager,
            sync_state=self.sync_state,
            qrl_node=self)  # FIXME: Try to avoid cyclic references
        # Publish the factory before it starts listening.
        self._p2pfactory = factory
        factory.start_listening()

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    @staticmethod
    def validate_amount(amount_str: str) -> bool:
        # FIXME: Refactored code. Review Decimal usage all over the code
        Decimal(amount_str)
        return True

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    @staticmethod
    def create_token_txn(symbol: bytes, name: bytes, owner: bytes,
                         decimals: int, initial_balances, fee: int,
                         xmss_pk: bytes, master_addr: bytes):
        """Build a token transaction via ``TokenTransaction.create``.

        All arguments are forwarded positionally, in the order declared here.
        """
        create_args = (symbol, name, owner, decimals, initial_balances, fee,
                       xmss_pk, master_addr)
        return TokenTransaction.create(*create_args)

    @staticmethod
    def create_transfer_token_txn(addrs_to: list, token_txhash: bytes,
                                  amounts: list, fee: int, xmss_pk: bytes,
                                  master_addr: bytes):
        """Build a token transfer via ``TransferTokenTransaction.create``.

        Note that ``token_txhash`` is passed before ``addrs_to`` to the
        underlying factory, unlike this method's own parameter order.
        """
        create_args = (token_txhash, addrs_to, amounts, fee, xmss_pk,
                       master_addr)
        return TransferTokenTransaction.create(*create_args)

    def create_send_tx(self, addrs_to: list, amounts: list, fee: int,
                       xmss_pk: bytes,
                       master_addr: bytes) -> TransferTransaction:
        addr_from = self.get_addr_from(xmss_pk, master_addr)
        balance = self.db_state.balance(addr_from)
        if sum(amounts) + fee > balance:
            raise ValueError("Not enough funds in the source address")

        return TransferTransaction.create(addrs_to=addrs_to,
                                          amounts=amounts,
                                          fee=fee,
                                          xmss_pk=xmss_pk,
                                          master_addr=master_addr)

    def create_slave_tx(self, slave_pks: list, access_types: list, fee: int,
                        xmss_pk: bytes,
                        master_addr: bytes) -> SlaveTransaction:
        """Build a slave transaction via ``SlaveTransaction.create``.

        Every argument is forwarded by keyword under its own name.
        """
        tx_fields = dict(slave_pks=slave_pks,
                         access_types=access_types,
                         fee=fee,
                         xmss_pk=xmss_pk,
                         master_addr=master_addr)
        return SlaveTransaction.create(**tx_fields)

    def create_lattice_public_key_txn(self, kyber_pk: bytes,
                                      dilithium_pk: bytes, fee: int,
                                      xmss_pk: bytes,
                                      master_addr: bytes) -> LatticePublicKey:
        """Build a lattice public key transaction.

        Every argument is forwarded by keyword to ``LatticePublicKey.create``
        and its result is returned. The return annotation now names that
        type; it previously said ``SlaveTransaction``.
        """
        return LatticePublicKey.create(kyber_pk=kyber_pk,
                                       dilithium_pk=dilithium_pk,
                                       fee=fee,
                                       xmss_pk=xmss_pk,
                                       master_addr=master_addr)

    # FIXME: Rename this appropriately
    def submit_send_tx(self, tx) -> bool:
        if tx is None:
            raise ValueError("The transaction was empty")

        if self._chain_manager.tx_pool.is_full_pending_transaction_pool():
            raise ValueError("Pending Transaction Pool is full")

        return self._p2pfactory.add_unprocessed_txn(
            tx,
            ip=None)  # TODO (cyyber): Replace None with IP made API request

    @staticmethod
    def get_addr_from(xmss_pk, master_addr):
        if master_addr:
            return master_addr

        return bytes(QRLHelper.getAddress(xmss_pk))

    def get_address_is_used(self, address: bytes) -> bool:
        """Return ``db_state.address_used(address)`` for a valid address.

        :raises ValueError: when ``AddressState.address_is_valid`` rejects
            ``address``.
        """
        address_ok = AddressState.address_is_valid(address)
        if address_ok:
            return self.db_state.address_used(address)
        raise ValueError("Invalid Address")

    def get_address_state(self, address: bytes) -> qrl_pb2.AddressState:
        """Return ``db_state.get_address_state(address)`` for a valid address.

        :raises ValueError: when ``AddressState.address_is_valid`` rejects
            ``address``.
        """
        address_ok = AddressState.address_is_valid(address)
        if address_ok:
            return self.db_state.get_address_state(address)
        raise ValueError("Invalid Address")

    def get_transaction(self, query_hash: bytes):
        """
        This method returns an object that matches the query hash
        """
        # FIXME: At some point, all objects in DB will indexed by a hash
        # TODO: Search tx hash
        # FIXME: We dont need searches, etc.. getting a protobuf indexed by hash from DB should be enough
        # FIXME: This is just a workaround to provide functionality
        result = self._chain_manager.get_transaction(query_hash)
        if result:
            return result[0], result[1]
        return None, None

    def get_block_last(self) -> Optional[Block]:
        """
        This method returns an object that matches the query hash
        """
        return self._chain_manager.get_last_block()

    def get_block_from_hash(self, query_hash: bytes) -> Optional[Block]:
        """
        This method returns an object that matches the query hash
        """
        return self.db_state.get_block(query_hash)

    def get_block_from_index(self, index: int) -> Block:
        """
        This method returns an object that matches the query hash
        """
        return self.db_state.get_block_by_number(index)

    def get_blockidx_from_txhash(self, transaction_hash):
        result = self.db_state.get_tx_metadata(transaction_hash)
        if result:
            return result[1]
        return None

    def get_token_detailed_list(self):
        """Return a ``TokenDetailedList`` describing every listed token.

        For each hash in the stored token list, the transaction is loaded
        through ``db_state.get_tx_metadata`` and wrapped in a
        ``TransactionExtended`` whose ``addr_from`` is taken from the
        transaction itself. Hashes without metadata are skipped.
        """
        pbdata = self.db_state.get_token_list()
        token_list = TokenList.from_json(pbdata)
        token_detailed_list = qrl_pb2.TokenDetailedList()
        for token_txhash in token_list.token_txhash:
            tx_metadata = self.db_state.get_tx_metadata(token_txhash)
            if not tx_metadata:
                # get_tx_metadata can return a falsy value (see
                # get_blockidx_from_txhash); unpacking it would raise.
                continue
            token_txn, _ = tx_metadata
            # addr_from belongs to the transaction; the hash being iterated
            # has no such attribute.
            transaction_extended = qrl_pb2.TransactionExtended(
                tx=token_txn.pbdata, addr_from=token_txn.addr_from)
            token_detailed_list.extended_tokens.extend([transaction_extended])
        return token_detailed_list

    def get_latest_blocks(self, offset, count) -> List[Block]:
        answer = []
        end = self.block_height - offset
        start = max(0, end - count - offset)
        for blk_idx in range(start, end + 1):
            answer.append(self._chain_manager.get_block_by_number(blk_idx))

        return answer

    def get_latest_transactions(self, offset, count):
        answer = []
        skipped = 0
        for tx in self.db_state.get_last_txs():
            if skipped >= offset:
                answer.append(tx)
                if len(answer) >= count:
                    break
            else:
                skipped += 1

        return answer

    def get_latest_transactions_unconfirmed(self, offset, count):
        answer = []
        skipped = 0
        for tx_set in self._chain_manager.tx_pool.transactions:
            if skipped >= offset:
                answer.append(tx_set[1].transaction)
                if len(answer) >= count:
                    break
            else:
                skipped += 1
        return answer

    def getNodeInfo(self) -> qrl_pb2.NodeInfo:
        """Return a ``NodeInfo`` message filled from this node's properties.

        ``block_last_hash`` is the last block's header hash and
        ``network_id`` is ``config.dev.genesis_prev_headerhash``.
        """
        node_info = qrl_pb2.NodeInfo()

        # Plain copies of same-named properties, in declaration order.
        for field in ('version', 'state', 'num_connections',
                      'num_known_peers', 'uptime', 'block_height'):
            setattr(node_info, field, getattr(self, field))

        last_block = self._chain_manager.get_last_block()
        node_info.block_last_hash = last_block.headerhash  # FIXME
        node_info.network_id = config.dev.genesis_prev_headerhash  # FIXME
        return node_info

    def get_block_timeseries(self,
                             block_count) -> Iterator[qrl_pb2.BlockDataPoint]:
        result = []

        if self._chain_manager.height == 0:
            return result

        block = self._chain_manager.get_last_block()
        if block is None:
            return result

        headerhash_current = block.headerhash
        while len(result) < block_count:
            data_point = self._chain_manager.state.get_block_datapoint(
                headerhash_current)

            if data_point is None:
                break

            result.append(data_point)
            headerhash_current = data_point.header_hash_prev

        return reversed(result)

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    def broadcast_ephemeral_message(
            self, encrypted_ephemeral: EncryptedEphemeralMessage) -> bool:
        if not encrypted_ephemeral.validate():
            return False

        self._p2pfactory.broadcast_ephemeral_message(encrypted_ephemeral)

        return True

    def collect_ephemeral_message(self, msg_id):
        return self.db_state.get_ephemeral_metadata(msg_id)

    ####################################################
    ####################################################
    ####################################################
    ####################################################

    def get_blockheader_and_metadata(self, block_number) -> list:
        if block_number == 0:
            block_number = self.block_height

        result = []
        block = self.get_block_from_index(block_number)
        if block:
            blockheader = block.blockheader
            blockmetadata = self.db_state.get_block_metadata(
                blockheader.headerhash)
            result = [blockheader, blockmetadata]

        return result

    def get_block_to_mine(self, wallet_address) -> list:
        last_block = self._chain_manager.get_last_block()
        last_block_metadata = self._chain_manager.state.get_block_metadata(
            last_block.headerhash)
        return self._pow.miner.get_block_to_mine(
            wallet_address, self._chain_manager.tx_pool, last_block,
            last_block_metadata.block_difficulty)

    def submit_mined_block(self, blob) -> bool:
        return self._pow.miner.submit_mined_block(blob)
예제 #2
0
class P2PPeerManager(P2PBaseObserver):
    class EventType(Enum):
        # Event identifiers that callers pass to register() to subscribe.
        # NO_PEERS is the event broadcast_chain_state() notifies.
        NO_PEERS = 1

    def __init__(self):
        """Initialise empty peer bookkeeping and the persisted ban set."""
        super().__init__()

        # Placeholders that are filled in later.
        self._ping_callLater = None
        self._disconnect_callLater = None
        self._p2p_factory = None

        # Live channels and the last chain state received from each.
        self._channels = []
        self._peer_node_status = {}

        self._known_peers = set()
        self.peers_path = os.path.join(config.user.data_dir,
                                       config.dev.peers_filename)

        # Banned IPs are kept in a file in the wallet directory and expire
        # after config.user.ban_minutes.
        self.banned_peers_filename = os.path.join(
            config.user.wallet_dir, config.dev.banned_peers_filename)
        ban_seconds = config.user.ban_minutes * 60
        self._banned_peer_ips = ExpiringSet(
            expiration_time=ban_seconds,
            filename=self.banned_peers_filename)

        self._observable = Observable(self)

    def register(self, message_type: EventType, func: Callable):
        """Subscribe ``func`` to ``message_type`` on the internal observable."""
        observable = self._observable
        observable.register(message_type, func)

    def set_p2p_factory(self, p2p_factory):
        """Store ``p2p_factory`` for later use by this manager."""
        self._p2p_factory = p2p_factory
        return None

    @property
    def known_peer_addresses(self):
        """Return the set of known peers (the live object, not a copy)."""
        known = self._known_peers
        return known

    def trusted_peer(self, channel: P2PProtocol):
        """Return whether ``channel`` currently counts as trusted.

        A channel is trusted when its peer is not banned and both its valid
        message count and its connection time reach the configured minimums
        (``config.dev.trust_min_msgcount`` / ``trust_min_conntime``).
        """
        # Checks short-circuit in the same order as separate guard clauses.
        untrusted = (
            self.is_banned(channel.peer)
            or channel.valid_message_count < config.dev.trust_min_msgcount
            or channel.connection_time < config.dev.trust_min_conntime
        )
        return not untrusted

    @property
    def trusted_addresses(self):
        """Return ``ip_public_port`` of every trusted connection.

        Connections whose ``public_port`` is 0 are left out.
        """
        return {
            peer.ip_public_port
            for peer in self._p2p_factory.connections
            if self.trusted_peer(peer) and peer.public_port != 0
        }

    @property
    def peer_node_status(self):
        """Return the channel -> chain state mapping (live object)."""
        status_by_channel = self._peer_node_status
        return status_by_channel

    def load_known_peers(self) -> List[str]:
        """Read the stored peer list and return its canonical addresses.

        The list is loaded as JSON from ``self.peers_path``. Any failure to
        read or parse the file is logged with its reason and treated as an
        empty list. Entries that cannot be canonicalised are logged and
        skipped instead of aborting the whole load.
        """
        known_peers = []
        try:
            logger.info('Loading known peers')
            with open(self.peers_path, 'r') as infile:
                known_peers = json.load(infile)
        except Exception as e:
            logger.info("Could not open known_peers list: %s", e)

        if not isinstance(known_peers, list):
            # Valid JSON that is not a list cannot be a peer list.
            logger.warning("Ignoring malformed known_peers list")
            known_peers = []

        canonical_peers = []
        for full_address in known_peers:
            try:
                canonical_peers.append(
                    IPMetadata.canonical_full_address(full_address))
            except Exception:
                logger.warning("Invalid Peer Address %s", full_address)

        return canonical_peers

    def save_known_peers(self, known_peers: List[str]):
        """Write ``known_peers`` to ``self.peers_path`` as JSON.

        At most ``3 * config.user.max_peers_limit`` entries are written.
        """
        limit = 3 * config.user.max_peers_limit
        trimmed = list(known_peers)[:limit]

        # Make sure the data directory exists before writing into it.
        config.create_path(config.user.data_dir)
        with open(self.peers_path, 'w') as outfile:
            json.dump(trimmed, outfile)

    def load_peer_addresses(self) -> None:
        """Merge stored and configured peers, then persist the result.

        The stored peers come from ``load_known_peers`` and the configured
        ones from ``config.user.peer_list``.
        """
        stored_peers = self.load_known_peers()
        merged = self.combine_peer_lists(stored_peers, config.user.peer_list)
        self._known_peers = merged
        logger.info('Loaded known peers: %s', self._known_peers)
        self.save_known_peers(self._known_peers)

    def extend_known_peers(self, new_peer_addresses: set) -> None:
        """Add ``new_peer_addresses`` to the known peers and persist them.

        Addresses that were not known before are first handed to the P2P
        factory's ``connect_peer`` (when a factory is set).
        """
        # Materialise once: building the set twice would leave the second
        # one empty for a one-shot iterable, so the peers would be connected
        # to but never remembered.
        incoming = set(new_peer_addresses)
        new_addresses = incoming - self._known_peers

        if self._p2p_factory is not None:
            self._p2p_factory.connect_peer(new_addresses)

        self._known_peers |= incoming
        self.save_known_peers(list(self._known_peers))

    @staticmethod
    def combine_peer_lists(peer_ips, sender_full_addresses: List, check_global=False) -> Set[IPMetadata]:
        """Return the canonical addresses found in both inputs, as a set.

        Every item of ``peer_ips`` and ``sender_full_addresses`` is passed
        through ``IPMetadata.canonical_full_address`` together with
        ``check_global``. Items that make it raise are logged and skipped.
        """
        candidates = list(peer_ips)
        candidates.extend(sender_full_addresses)

        answer = set()
        for item in candidates:
            try:
                answer.add(IPMetadata.canonical_full_address(item, check_global))
            except Exception:
                # ``Exception`` rather than a bare except, so that
                # KeyboardInterrupt and SystemExit are not swallowed.
                logger.warning("Invalid Peer Address %s", item)

        return answer

    def get_better_difficulty(self, current_cumulative_difficulty):
        """Return the channel reporting the highest cumulative difficulty.

        Only channels whose reported value is strictly greater than
        ``current_cumulative_difficulty`` qualify; ``None`` is returned when
        there is none.
        """
        local_difficulty = int(UInt256ToString(current_cumulative_difficulty))
        top_difficulty = local_difficulty
        top_channel = None

        for channel, chain_state in self._peer_node_status.items():
            remote_difficulty = int(
                UInt256ToString(chain_state.cumulative_difficulty))
            if remote_difficulty > top_difficulty:
                top_difficulty, top_channel = remote_difficulty, channel

        logger.debug('Local Best Diff : %s', local_difficulty)
        logger.debug('Remote Best Diff : %s', top_difficulty)
        return top_channel

    def insert_to_last_connected_peer(self, ip_public_port, connected_peer=False):
        """Reposition ``ip_public_port`` in the stored peer list and save it.

        The address is re-inserted right after the run of currently
        connected peers: searched from the front of the list when
        ``connected_peer`` is true, or from the address's previous position
        otherwise. When ``connected_peer`` is false and the address is not
        stored, nothing is changed or saved.
        """
        stored_peers = self.load_known_peers()

        # Addresses of the peers the factory is connected to right now.
        connected = set()
        if self._p2p_factory is not None:
            for conn in self._p2p_factory._peer_connections:
                connected.add(conn.ip_public_port)

        try:
            if connected_peer:
                position = 0
                if ip_public_port in stored_peers:
                    stored_peers.remove(ip_public_port)
            else:
                # Raises ValueError for an unknown address (handled below).
                position = stored_peers.index(ip_public_port)
                del stored_peers[position]

            # Advance past connected peers to find the insertion point.
            while position < len(stored_peers):
                if stored_peers[position] not in connected:
                    break
                position += 1

            stored_peers.insert(position, ip_public_port)
            self.save_known_peers(stored_peers)
        except ValueError:
            pass

    def remove_channel(self, channel):
        """Forget ``channel`` and reposition its address in the stored peers."""
        self.insert_to_last_connected_peer(channel.ip_public_port)

        if channel in self._channels:
            self._channels.remove(channel)

        # Drop the stored chain state, if any.
        self._peer_node_status.pop(channel, None)

    def new_channel(self, channel):
        """Track ``channel`` and register this manager's message handlers.

        The channel starts with a zeroed chain state stamped with the
        current ``ntp`` time.
        """
        self._channels.append(channel)
        self._peer_node_status[channel] = qrl_pb2.NodeChainState(
            block_number=0,
            header_hash=b'',
            cumulative_difficulty=b'\x00' * 32,
            timestamp=ntp.getTime())

        message_types = qrllegacy_pb2.LegacyMessage
        handlers = (
            (message_types.VE, self.handle_version),
            (message_types.PL, self.handle_peer_list),
            (message_types.CHAINSTATE, self.handle_chain_state),
            (message_types.SYNC, self.handle_sync),
            (message_types.P2P_ACK, self.handle_p2p_acknowledgement),
        )
        for message_type, handler in handlers:
            channel.register(message_type, handler)

    def _get_version_compatibility(self, version) -> bool:
        """Return whether a peer running ``version`` is still acceptable.

        Always true on testnet (hard fork heights equal the testnet ones)
        and while no P2P factory is set. Otherwise the peer's major version
        must be at least 2, 3 and 4 once the chain height reaches the first
        hard fork, the second hard fork, and the third hard fork plus its
        disconnect delay, respectively. Versions whose major part cannot be
        parsed are accepted.
        """
        # Ignore compatibility test on Testnet
        if config.dev.hard_fork_heights == config.dev.testnet_hard_fork_heights:
            return True

        if self._p2p_factory is None:
            return True

        try:
            major_version = int(version.split(".")[0])
        except Exception:
            # Disabled warning as it is not required and could be annoying
            # if a peer with dirty version is trying to connect with the node
            # logger.warning("Exception while checking version for compatibility")
            return True

        fork_heights = config.dev.hard_fork_heights
        # (chain height from which the rule applies, minimum major version)
        requirements = (
            (fork_heights[0], 2),
            (fork_heights[1], 3),
            (fork_heights[2] + config.dev.hard_fork_node_disconnect_delay[2], 4),
        )

        chain_height = self._p2p_factory.chain_height
        return not any(
            chain_height >= activation_height and major_version < min_major
            for activation_height, min_major in requirements)

    def handle_version(self, source, message: qrllegacy_pb2.LegacyMessage):
        """
        Version
        An empty version is a request: reply with this node's version,
        genesis_prev_headerhash and rate limit.
        Otherwise the peer's data is checked; an incompatible version or a
        genesis_prev_headerhash mismatch disconnects and bans the peer.
        """
        self._validate_message(message, qrllegacy_pb2.LegacyMessage.VE)

        peer_version = message.veData.version
        if not peer_version:
            own_data = qrllegacy_pb2.VEData(
                version=config.dev.version,
                genesis_prev_hash=config.user.genesis_prev_headerhash,
                rate_limit=config.user.peer_rate_limit)
            reply = qrllegacy_pb2.LegacyMessage(
                func_name=qrllegacy_pb2.LegacyMessage.VE,
                veData=own_data)
            source.send(reply)
            return

        logger.info('%s version: %s | genesis prev_headerhash %s',
                    source.peer.ip,
                    peer_version,
                    message.veData.genesis_prev_hash)

        if not self._get_version_compatibility(peer_version):
            logger.warning("Disconnecting from Peer %s running incompatible node version %s",
                           source.peer.ip,
                           peer_version)
            source.loseConnection()
            self.ban_channel(source)
            return

        # Use the stricter of our own limit and the one the peer announced.
        source.rate_limit = min(config.user.peer_rate_limit,
                                message.veData.rate_limit)

        if message.veData.genesis_prev_hash == config.user.genesis_prev_headerhash:
            return

        logger.warning('%s genesis_prev_headerhash mismatch', source.peer)
        logger.warning('Expected: %s', config.user.genesis_prev_headerhash)
        logger.warning('Found: %s', message.veData.genesis_prev_hash)
        source.loseConnection()
        self.ban_channel(source)
    def handle_peer_list(self, source, message: qrllegacy_pb2.LegacyMessage):
        """Process a peer list message and queue the advertised peers.

        Nothing happens when peer discovery is disabled, the list is empty
        or the announced public port is outside 1..65535. Otherwise the
        sender's public port is recorded and the combined peer list is
        handed to the P2P factory (when one is set).
        """
        P2PBaseObserver._validate_message(message, qrllegacy_pb2.LegacyMessage.PL)

        if not config.user.enable_peer_discovery:
            return

        if not message.plData.peer_ips:
            return

        # If public port is invalid, ignore rest of the data
        announced_port = message.plData.public_port
        if not (0 < announced_port < 65536):
            return

        source.set_public_port(message.plData.public_port)
        self.insert_to_last_connected_peer(source.ip_public_port, True)

        sender = IPMetadata(source.peer.ip, message.plData.public_port)

        # Check if peer list contains global ip, if it was sent by peer from a global ip address
        sender_is_global = IPv4Address(source.peer.ip).is_global
        new_peers = self.combine_peer_lists(message.plData.peer_ips,
                                            [sender.full_address],
                                            check_global=sender_is_global)

        logger.info('%s peers data received: %s', source.peer.ip, new_peers)
        if self._p2p_factory is not None:
            self._p2p_factory.add_new_peers_to_peer_q(new_peers)

    def handle_sync(self, source, message: qrllegacy_pb2.LegacyMessage):
        """Answer an empty-state SYNC message when the factory is synced."""
        P2PBaseObserver._validate_message(message, qrllegacy_pb2.LegacyMessage.SYNC)

        is_sync_request = message.syncData.state == ''
        if is_sync_request and source.factory.synced:
            source.send_sync(synced=True)

    @staticmethod
    def send_node_chain_state(dest_channel, node_chain_state: qrl_pb2.NodeChainState):
        """Send ``node_chain_state`` to ``dest_channel`` as a CHAINSTATE message."""
        # FIXME: Not sure this belongs to peer management
        chain_state_message = qrllegacy_pb2.LegacyMessage(
            func_name=qrllegacy_pb2.LegacyMessage.CHAINSTATE,
            chainStateData=node_chain_state)
        dest_channel.send(chain_state_message)

    def monitor_chain_state(self):
        """Disconnect channels without a recent chain state.

        A channel is disconnected when it has no stored chain state, or when
        its state is older than ``config.user.chain_state_timeout`` (in
        which case the stored state is also removed).
        """
        # FIXME: Not sure this belongs to peer management
        now = ntp.getTime()
        for channel in self._channels:
            if channel in self._peer_node_status:
                idle_for = now - self._peer_node_status[channel].timestamp
                if not (idle_for > config.user.chain_state_timeout):
                    # Fresh enough: keep the channel.
                    continue
                del self._peer_node_status[channel]
                logger.debug('>>>> No State Update [%18s] %2.2f (TIMEOUT)', channel.peer, idle_for)
            channel.loseConnection()

    def broadcast_chain_state(self, node_chain_state: qrl_pb2.NodeChainState):
        """Send ``node_chain_state`` to every channel, then notify NO_PEERS.

        The NO_PEERS event is raised on every call, regardless of how many
        channels were pinged.
        """
        # FIXME: Not sure this belongs to peer management
        # TODO: Verify/Disconnect problematic channels
        # Ping all channels
        for peer_channel in self._channels:
            self.send_node_chain_state(peer_channel, node_chain_state)

        no_peers_event = ObservableEvent(self.EventType.NO_PEERS)
        self._observable.notify(no_peers_event)

    def handle_chain_state(self, source, message: qrllegacy_pb2.LegacyMessage):
        """Store the chain state a peer reported, or disconnect a bad peer.

        The message is stamped with the local receive time. The peer is
        disconnected when its cumulative difficulty cannot be decoded (the
        state is not stored) or when the version carried in the chain state
        is incompatible.
        """
        P2PBaseObserver._validate_message(message, qrllegacy_pb2.LegacyMessage.CHAINSTATE)

        message.chainStateData.timestamp = ntp.getTime()  # Receiving time

        try:
            UInt256ToString(message.chainStateData.cumulative_difficulty)
        except ValueError:
            logger.warning('Invalid Cumulative Difficulty sent by peer')
            source.loseConnection()
            return

        self._peer_node_status[source] = message.chainStateData

        if not self._get_version_compatibility(message.chainStateData.version):
            # Log the version that was actually checked; veData belongs to
            # VE messages and is not the field tested above.
            logger.warning("Disconnecting from Peer %s running incompatible node version %s",
                           source.peer.ip,
                           message.chainStateData.version)
            source.loseConnection()
            return

    def handle_p2p_acknowledgement(self, source, message: qrllegacy_pb2.LegacyMessage):
        """Account for bytes a peer acknowledged and send its next message.

        ``bytes_processed`` is subtracted from ``source.bytes_sent``. A
        negative result means the peer acknowledged more than was sent: the
        peer is disconnected and nothing further is sent to it.
        """
        P2PBaseObserver._validate_message(message, qrllegacy_pb2.LegacyMessage.P2P_ACK)

        source.bytes_sent -= message.p2pAckData.bytes_processed
        if source.bytes_sent < 0:
            logger.warning('Disconnecting Peer %s', source.peer)
            logger.warning('Reason: negative bytes_sent value')
            logger.warning('bytes_sent %s', source.bytes_sent)
            logger.warning('Ack bytes_processed %s', message.p2pAckData.bytes_processed)
            source.loseConnection()
            # Stop here, like the other handlers do after disconnecting.
            return

        source.send_next()

    ####################################################
    ####################################################
    ####################################################
    ####################################################
    def is_banned(self, peer: IPMetadata):
        """Return whether ``peer.ip`` is in the banned-IP set."""
        banned_ips = self._banned_peer_ips
        return peer.ip in banned_ips

    def ban_channel(self, channel: P2PProtocol):
        """Add the channel's peer IP to the ban set and disconnect it."""
        peer_ip = channel.peer.ip
        self._banned_peer_ips.add(peer_ip)
        logger.warning('Banned %s', peer_ip)
        channel.loseConnection()

    def get_peers_stat(self) -> list:
        """Return one ``PeerStat`` per channel with a stored chain state.

        Each entry carries the peer's encoded IP, its port and its stored
        chain state.
        """
        collected = []
        status_by_channel = self.peer_node_status
        # Copying the list of keys, to avoid any change by other thread
        for source in list(status_by_channel.keys()):
            try:
                collected.append(
                    qrl_pb2.PeerStat(peer_ip=source.peer.ip.encode(),
                                     port=source.peer.port,
                                     node_chain_state=status_by_channel[source]))
            except KeyError:
                # Ignore in case the key is deleted by other thread causing KeyError
                continue
        return collected
예제 #3
0
class P2PPeerManager(P2PBaseObserver):
    class EventType(Enum):
        """Event identifiers that callbacks can be subscribed to through register()."""
        # Sent to the observers at the end of broadcast_chain_state()
        NO_PEERS = 1

    def __init__(self):
        """
        Create a manager with no channels, no known peers and no P2P factory.

        File locations are derived from ``config``: the known-peers file lives
        under ``config.user.data_dir`` and the banned-peers file under
        ``config.user.wallet_dir``.
        """
        super().__init__()

        # Initialised to None here; not used by the code in this section
        self._ping_callLater = None
        self._disconnect_callLater = None

        # Channels added via new_channel() and the chain state stored per channel
        self._channels = []
        self._peer_node_status = {}

        # Known peer addresses and the file they are saved to
        self._known_peers = set()
        self.peers_path = os.path.join(config.user.data_dir,
                                       config.dev.peers_filename)

        # Banned IPs are kept in an ExpiringSet configured with the ban
        # duration (minutes converted to seconds) and the ban file path
        self.banned_peers_filename = os.path.join(config.user.wallet_dir,
                                                  config.dev.banned_peers_filename)
        ban_seconds = config.user.ban_minutes * 60
        self._banned_peer_ips = ExpiringSet(expiration_time=ban_seconds,
                                            filename=self.banned_peers_filename)

        self._observable = Observable(self)

        # NOTE(review): never assigned in this section; presumably attached by
        # the owner before connect_peers() is used - confirm
        self._p2pfactory = None

    def register(self, message_type: EventType, func: Callable):
        """Subscribe ``func`` to ``message_type`` by forwarding both to the internal observable."""
        self._observable.register(message_type, func)

    @property
    def known_peer_addresses(self):
        """Known peer addresses; this is the internal set itself, not a copy."""
        return self._known_peers

    def trusted_peer(self, channel: P2PProtocol):
        """
        Tell whether a channel counts as trusted.

        A channel is rejected when its peer is banned, when its
        valid_message_count is below config.dev.trust_min_msgcount, or when
        its connection_time is below config.dev.trust_min_conntime. The checks
        run in that order and stop at the first one that fails.

        :return: True if none of the rejection rules applies, otherwise False
        """
        return not (
            self.is_banned(channel.peer)
            or channel.valid_message_count < config.dev.trust_min_msgcount
            or channel.connection_time < config.dev.trust_min_conntime
        )

    @property
    def trusted_addresses(self):
        """
        Full addresses of the factory's connections that pass trusted_peer().

        :return: set of ``peer.full_address`` values; an empty set while no
                 P2P factory is attached
        """
        if self._p2pfactory is None:
            # No factory means no connections (extend_known_peers() also
            # treats a missing factory as a valid state)
            return set()

        return {
            channel.peer.full_address
            for channel in self._p2pfactory.connections
            if self.trusted_peer(channel)
        }

    @property
    def peer_node_status(self):
        """Mapping of channel to the chain state stored for it (the internal dict, not a copy)."""
        return self._peer_node_status

    def load_known_peers(self) -> List[str]:
        """
        Read the stored peer list from ``self.peers_path``.

        Loading is best effort: if the file cannot be opened or parsed, an
        empty list is used. Content that is not a JSON list is ignored, and
        entries that IPMetadata.canonical_full_address rejects are skipped
        with a warning (the same policy combine_peer_lists() applies).

        :return: canonical full addresses of the usable stored peers
        """
        known_peers = []
        try:
            logger.info('Loading known peers')
            with open(self.peers_path, 'r') as infile:
                known_peers = json.load(infile)
        except Exception:
            # Deliberately broad: any failure here just means "no stored peers"
            logger.info("Could not open known_peers list")

        if not isinstance(known_peers, list):
            # save_known_peers() always writes a JSON list
            logger.warning("Ignoring known_peers file with unexpected content")
            return []

        canonical_peers = []
        for full_address in known_peers:
            try:
                canonical_peers.append(
                    IPMetadata.canonical_full_address(full_address))
            except Exception:
                # One damaged entry must not prevent the others from loading
                logger.warning("Invalid Peer Address %s", full_address)

        return canonical_peers

    def save_known_peers(self, known_peers: List[str]):
        """
        Write the given peers as a JSON list to ``self.peers_path``.

        The data directory is created through config.create_path before the
        file is opened for writing.

        :param known_peers: iterable of peer addresses; it is copied into a
                            list before being serialised
        """
        peers_to_store = list(known_peers)
        config.create_path(config.user.data_dir)
        with open(self.peers_path, 'w') as peers_file:
            json.dump(peers_to_store, peers_file)

    def load_peer_addresses(self) -> None:
        """
        Initialise the known peers from disk plus the configured peer list.

        The stored peers and ``config.user.peer_list`` are merged with
        combine_peer_lists(); the result replaces ``self._known_peers``, is
        logged and is written back with save_known_peers().
        """
        stored_peers = self.load_known_peers()
        self._known_peers = self.combine_peer_lists(stored_peers,
                                                    config.user.peer_list)
        logger.info('Loaded known peers: %s', self._known_peers)
        self.save_known_peers(self._known_peers)

    def extend_known_peers(self, new_peer_addresses: set) -> None:
        """
        Merge newly learned peer addresses into the known peers.

        Addresses that are not known yet are handed to the P2P factory's
        connect_peer() (when a factory is attached), added to the known set
        in place, and the updated set is saved. When every address is already
        known, nothing is connected and nothing is written to disk.

        :param new_peer_addresses: iterable of peer addresses
        """
        new_addresses = set(new_peer_addresses) - self._known_peers
        if not new_addresses:
            # Nothing changed, so skip the redundant rewrite of the peers file
            return

        if self._p2pfactory is not None:
            for peer_address in new_addresses:
                self._p2pfactory.connect_peer(peer_address)

        # In-place update keeps the set object exposed by known_peer_addresses
        self._known_peers |= new_addresses
        self.save_known_peers(list(self._known_peers))

    @staticmethod
    def combine_peer_lists(peer_ips,
                           sender_full_addresses: List,
                           check_global=False) -> Set[IPMetadata]:
        """
        Merge two collections of peer addresses into one canonical set.

        Every item of both inputs is passed through
        IPMetadata.canonical_full_address; items it rejects are logged and
        left out.

        NOTE(review): the return annotation says Set[IPMetadata], but the
        elements are whatever canonical_full_address returns;
        load_known_peers() annotates results of the same call as str - confirm.

        :param peer_ips: iterable of peer addresses
        :param sender_full_addresses: further addresses to include
        :param check_global: forwarded to canonical_full_address
        :return: set of canonical addresses
        """
        candidates = list(peer_ips) + list(sender_full_addresses)

        answer = set()
        for item in candidates:
            try:
                answer.add(
                    IPMetadata.canonical_full_address(item, check_global))
            except Exception:
                # Any rejection is treated as an invalid address, but
                # KeyboardInterrupt/SystemExit must still propagate
                logger.warning("Invalid Peer Address %s", item)

        return answer

    def get_better_difficulty(self, current_cumulative_difficulty):
        """
        Find the channel that reports the highest cumulative difficulty.

        Difficulties are compared as integers parsed from the output of
        UInt256ToString. Only a value strictly greater than the best one seen
        so far (starting from the local value) is accepted, so on a tie the
        channel encountered first wins.

        :param current_cumulative_difficulty: local cumulative difficulty
        :return: the best channel, or None when no peer exceeds the local value
        """
        local_difficulty = int(UInt256ToString(current_cumulative_difficulty))
        highest_difficulty = local_difficulty
        leading_channel = None

        for channel, chain_state in self._peer_node_status.items():
            peer_difficulty = int(
                UInt256ToString(chain_state.cumulative_difficulty))
            if peer_difficulty > highest_difficulty:
                highest_difficulty, leading_channel = peer_difficulty, channel

        logger.debug('Local Best Diff : %s', local_difficulty)
        logger.debug('Remote Best Diff : %s', highest_difficulty)
        return leading_channel

    def remove_channel(self, channel):
        """
        Forget a channel.

        Removes one occurrence of the channel from the channel list and
        discards the chain state stored for it. Unknown channels are ignored.
        """
        try:
            self._channels.remove(channel)
        except ValueError:
            # The channel was not in the list; nothing to remove
            pass
        self._peer_node_status.pop(channel, None)

    def new_channel(self, channel):
        """
        Start tracking a channel and subscribe this manager's handlers on it.

        The channel is appended to the channel list and given an initial
        chain state (block 0, empty header hash, all-zero 32-byte cumulative
        difficulty, timestamp from ntp.getTime()). The handlers for VE, PL,
        CHAINSTATE, SYNC and P2P_ACK messages are then registered on it.
        """
        self._channels.append(channel)
        self._peer_node_status[channel] = qrl_pb2.NodeChainState(
            block_number=0,
            header_hash=b'',
            cumulative_difficulty=b'\x00' * 32,
            timestamp=ntp.getTime())

        message_types = qrllegacy_pb2.LegacyMessage
        handlers = (
            (message_types.VE, self.handle_version),
            (message_types.PL, self.handle_peer_list),
            (message_types.CHAINSTATE, self.handle_chain_state),
            (message_types.SYNC, self.handle_sync),
            (message_types.P2P_ACK, self.handle_p2p_acknowledgement),
        )
        for message_type, handler in handlers:
            channel.register(message_type, handler)

    def handle_version(self, source, message: qrllegacy_pb2.LegacyMessage):
        """
        Handle a VE (version) message.

        An empty version field is answered with the local version,
        genesis_prev_headerhash and rate limit, and nothing else is done.
        Otherwise the peer's data is logged, the channel's rate limit is set
        to the smaller of the local and the announced limit, and the peer is
        disconnected if its genesis_prev_hash differs from the configured one.
        """
        self._validate_message(message, qrllegacy_pb2.LegacyMessage.VE)
        peer_info = message.veData

        if not peer_info.version:
            # The peer asked for our data: reply and stop here
            local_info = qrllegacy_pb2.VEData(
                version=config.dev.version,
                genesis_prev_hash=config.user.genesis_prev_headerhash,
                rate_limit=config.user.peer_rate_limit)
            source.send(qrllegacy_pb2.LegacyMessage(
                func_name=qrllegacy_pb2.LegacyMessage.VE,
                veData=local_info))
            return

        logger.info('%s version: %s | genesis prev_headerhash %s',
                    source.peer.ip, peer_info.version,
                    peer_info.genesis_prev_hash)

        source.rate_limit = min(config.user.peer_rate_limit,
                                peer_info.rate_limit)

        expected_hash = config.user.genesis_prev_headerhash
        if peer_info.genesis_prev_hash == expected_hash:
            return

        logger.warning('%s genesis_prev_headerhash mismatch', source.peer)
        logger.warning('Expected: %s', expected_hash)
        logger.warning('Found: %s', peer_info.genesis_prev_hash)
        source.loseConnection()

    def handle_peer_list(self, source, message: qrllegacy_pb2.LegacyMessage):
        """
        Handle a PL (peer list) message.

        Nothing happens when peer discovery is disabled in the config or the
        message carries no peer IPs. Otherwise the received addresses and the
        sender's own address (its IP plus the announced public port) are
        canonicalised with check_global=True, the local address is removed,
        and the result is passed to extend_known_peers().
        """
        P2PBaseObserver._validate_message(message,
                                          qrllegacy_pb2.LegacyMessage.PL)

        if not (config.user.enable_peer_discovery and message.plData.peer_ips):
            return

        peer_list = message.plData
        sender = IPMetadata(source.peer.ip, peer_list.public_port)

        discovered = self.combine_peer_lists(peer_list.peer_ips,
                                             [sender.full_address],
                                             check_global=True)
        discovered.discard(source.host.full_address)  # Remove local address

        logger.info('%s peers data received: %s', source.peer.ip, discovered)
        self.extend_known_peers(discovered)

    def handle_sync(self, source, message: qrllegacy_pb2.LegacyMessage):
        """
        Handle a SYNC message.

        When the message's sync state is the empty string and the channel's
        factory reports ``synced``, reply with ``send_sync(synced=True)``.
        In every other case nothing is sent.
        """
        P2PBaseObserver._validate_message(message,
                                          qrllegacy_pb2.LegacyMessage.SYNC)
        has_empty_state = message.syncData.state == ''
        if has_empty_state and source.factory.synced:
            source.send_sync(synced=True)

    @staticmethod
    def send_node_chain_state(dest_channel,
                              node_chain_state: qrl_pb2.NodeChainState):
        """Wrap the chain state in a CHAINSTATE legacy message and send it on the channel."""
        # FIXME: Not sure this belongs to peer management
        dest_channel.send(
            qrllegacy_pb2.LegacyMessage(
                func_name=qrllegacy_pb2.LegacyMessage.CHAINSTATE,
                chainStateData=node_chain_state))

    def monitor_chain_state(self):
        """
        Disconnect channels whose chain state is missing or too old.

        A channel without a stored chain state is disconnected. A channel
        whose stored state has a timestamp more than
        config.user.chain_state_timeout behind ntp.getTime() has that state
        deleted and is disconnected as well.
        """
        # FIXME: Not sure this belongs to peer management
        current_timestamp = ntp.getTime()
        # Iterate over a snapshot: if dropping a connection ends up calling
        # remove_channel(), mutating self._channels during the loop would
        # make it skip the following channel.
        for channel in list(self._channels):
            if channel not in self._peer_node_status:
                channel.loseConnection()
                continue
            delta = current_timestamp - self._peer_node_status[
                channel].timestamp
            if delta > config.user.chain_state_timeout:
                del self._peer_node_status[channel]
                logger.debug('>>>> No State Update [%18s] %2.2f (TIMEOUT)',
                             channel.peer, delta)
                channel.loseConnection()

    def broadcast_chain_state(self, node_chain_state: qrl_pb2.NodeChainState):
        """
        Send the given chain state to every channel, then notify observers.

        A NO_PEERS event is published after the broadcast on every call,
        whether or not any channel exists.
        """
        # FIXME: Not sure this belongs to peer management
        # TODO: Verify/Disconnect problematic channels
        # Ping all channels
        for peer_channel in self._channels:
            self.send_node_chain_state(peer_channel, node_chain_state)

        no_peers_event = ObservableEvent(self.EventType.NO_PEERS)
        self._observable.notify(no_peers_event)

    def handle_chain_state(self, source, message: qrllegacy_pb2.LegacyMessage):
        """
        Handle a CHAINSTATE message.

        The received chain state is stamped with the local receiving time. If
        its cumulative difficulty cannot be converted by UInt256ToString
        (ValueError), the peer is disconnected and nothing is stored.
        Otherwise the chain state replaces the one stored for the channel.
        """
        # FIXME: Not sure this belongs to peer management
        P2PBaseObserver._validate_message(
            message, qrllegacy_pb2.LegacyMessage.CHAINSTATE)

        message.chainStateData.timestamp = ntp.getTime()  # Receiving time
        reported_state = message.chainStateData

        try:
            # Called only to check that the value converts
            UInt256ToString(reported_state.cumulative_difficulty)
        except ValueError:
            logger.warning('Invalid Cumulative Difficulty sent by peer')
            source.loseConnection()
        else:
            self._peer_node_status[source] = reported_state

    def handle_p2p_acknowledgement(self, source,
                                   message: qrllegacy_pb2.LegacyMessage):
        """
        Handle a P2P_ACK message from a peer.

        The message is checked with P2PBaseObserver._validate_message, then
        message.p2pAckData.bytes_processed is subtracted from
        source.bytes_sent. If the counter becomes negative the peer is logged
        and disconnected, and nothing more is sent to it. Otherwise
        source.send_next() is called.

        :param source: channel the acknowledgement arrived on
        :param message: legacy message carrying p2pAckData
        """
        P2PBaseObserver._validate_message(message,
                                          qrllegacy_pb2.LegacyMessage.P2P_ACK)

        source.bytes_sent -= message.p2pAckData.bytes_processed
        if source.bytes_sent < 0:
            logger.warning('Disconnecting Peer %s', source.peer)
            logger.warning('Reason: negative bytes_sent value')
            logger.warning('bytes_sent %s', source.bytes_sent)
            logger.warning('Ack bytes_processed %s',
                           message.p2pAckData.bytes_processed)
            source.loseConnection()
            # The peer is being dropped: do not trigger another send on this
            # channel (handle_chain_state also returns right after
            # disconnecting).
            return

        source.send_next()

    ####################################################
    ####################################################
    ####################################################
    ####################################################
    def is_banned(self, peer: IPMetadata):
        """Return whether ``peer.ip`` is contained in the banned-IP set."""
        return peer.ip in self._banned_peer_ips

    def ban_channel(self, channel: P2PProtocol):
        """Add the IP of the channel's peer to the banned set, log it and close the channel."""
        peer_ip = channel.peer.ip
        self._banned_peer_ips.add(peer_ip)
        logger.warning('Banned %s', peer_ip)
        channel.loseConnection()

    def connect_peers(self):
        """
        Ask the P2P factory to connect to every known peer that is not banned.

        Each known address is converted with IPMetadata.from_full_address for
        the ban check; banned addresses are skipped.
        """
        known_addresses = self.known_peer_addresses
        logger.info('<<<Reconnecting to peer list: %s',
                    known_addresses)
        for full_address in known_addresses:
            peer = IPMetadata.from_full_address(full_address)
            if not self.is_banned(peer):
                self._p2pfactory.connect_peer(full_address)

    def get_peers_stat(self) -> list:
        """
        Build one qrl_pb2.PeerStat per entry of ``self.peer_node_status``.

        Each PeerStat carries the encoded peer IP, the peer port and the chain
        state stored for that channel.

        :return: list of qrl_pb2.PeerStat
        """
        stats = []
        # Walk a snapshot of the keys, so that a change made elsewhere
        # (the original note mentions another thread) cannot break the loop
        for channel in tuple(self.peer_node_status.keys()):
            try:
                stats.append(
                    qrl_pb2.PeerStat(
                        peer_ip=channel.peer.ip.encode(),
                        port=channel.peer.port,
                        node_chain_state=self.peer_node_status[channel]))
            except KeyError:
                # The entry disappeared after the snapshot was taken; skip it
                continue
        return stats