コード例 #1
0
class TxnVersionController(TVController):
    """
    Tracks which version was accepted at which transaction time.

    A version is accepted once more than ``f`` distinct senders have
    reported a completed NODE_UPGRADE to it; ``f`` is recomputed from the
    schedule size of every started POOL_UPGRADE.
    """

    def __init__(self) -> None:
        # txn time at which a version was accepted -> that version
        self._versions = SortedDict()
        self._f = 0
        # version -> set of senders that reported a completed upgrade to it
        self._votes_for_new_version = SortedDict()

    @property
    def version(self):
        """The version stored under the greatest txn time, or None if none is stored."""
        if not self._versions:
            return None
        _, latest = self._versions.peekitem(-1)
        return latest

    def get_pool_version(self, timestamp):
        """
        Return the version that was in effect at ``timestamp``.

        With ``timestamp`` set to None the latest version is returned.
        Otherwise the result is the version with the greatest acceptance
        time not later than ``timestamp`` (None when there is none).
        """
        if timestamp is None:
            return self.version
        effective = None
        for upgrade_tm, accepted in self._versions.items():
            if timestamp < upgrade_tm:
                # every remaining entry is even later
                break
            effective = accepted
        return effective

    def update_version(self, txn):
        """
        Feed a transaction into the controller.

        A started POOL_UPGRADE only refreshes ``f``; a completed
        NODE_UPGRADE counts as a vote of its sender for the reported version.
        Any other transaction is ignored.
        """
        txn_type = get_type(txn)
        if txn_type == POOL_UPGRADE and get_payload_data(txn).get(ACTION) == START:
            schedule = get_payload_data(txn).get(SCHEDULE, {})
            self._f = (len(schedule) - 1) // 3
            return

        if not (txn_type == NODE_UPGRADE and get_payload_data(txn)[DATA][ACTION] == COMPLETE):
            return

        version = get_payload_data(txn)[DATA][VERSION]
        voters = self._votes_for_new_version.setdefault(version, set())
        voters.add(get_from(txn))
        if len(voters) > self._f:
            self._versions[get_txn_time(txn)] = version
            # Keep only the votes for versions that compare greater than
            # the one just accepted.
            # NOTE(review): versions are compared with plain ``>``; confirm
            # that this ordering is the intended one for the version type.
            remaining = {}
            for candidate, senders in self._votes_for_new_version.items():
                if candidate > version:
                    remaining[candidate] = senders
            self._votes_for_new_version = SortedDict(remaining)
コード例 #2
0
class TxnVersionController(ITxnVersionController):
    """
    Tracks the currently accepted version.

    A version is accepted once more than ``f`` distinct senders have
    reported a completed NODE_UPGRADE to it; ``f`` is recomputed from the
    schedule size of every started POOL_UPGRADE.
    """

    def __init__(self) -> None:
        self._version = None
        self._f = 0
        # version -> set of senders that reported a completed upgrade to it
        self._votes_for_new_version = SortedDict()

    @property
    def version(self):
        """The last accepted version, or None if none has been accepted."""
        return self._version

    def update_version(self, txn):
        """
        Feed a transaction into the controller.

        A started POOL_UPGRADE only refreshes ``f``; a completed
        NODE_UPGRADE counts as a vote of its sender for the reported version.
        Any other transaction is ignored.
        """
        txn_type = get_type(txn)
        if txn_type == POOL_UPGRADE and get_payload_data(txn).get(ACTION) == START:
            schedule = get_payload_data(txn).get(SCHEDULE, {})
            self._f = (len(schedule) - 1) // 3
            return

        if not (txn_type == NODE_UPGRADE and get_payload_data(txn)[DATA][ACTION] == COMPLETE):
            return

        version = get_payload_data(txn)[DATA][VERSION]
        voters = self._votes_for_new_version.setdefault(version, set())
        voters.add(get_from(txn))
        if len(voters) > self._f:
            self._version = version
            # Keep only the votes for versions that compare greater than
            # the one just accepted.
            remaining = {}
            for candidate, senders in self._votes_for_new_version.items():
                if candidate > version:
                    remaining[candidate] = senders
            self._votes_for_new_version = SortedDict(remaining)
コード例 #3
0
class Replicas:
    """
    Collection of the protocol-instance replicas hosted by one node.

    Replicas are stored in a ``SortedDict`` keyed by instance id. The class
    offers a mapping-like read interface over that storage (``keys``,
    ``values``, ``items``, indexing, ``len``; iteration yields instance ids).
    """
    _replica_class = Replica

    def __init__(self,
                 node,
                 monitor: Monitor,
                 config=None,
                 metrics: MetricsCollector = NullMetricsCollector()):
        """
        :param node: owning node; handed to every replica that is created
        :param monitor: told about added/removed instances; the unordered
            request logging handler is registered on it
        :param config: passed through to every created replica
        :param metrics: passed through to every created replica
        """
        # passing full node because Replica requires it
        self._node = node
        self._monitor = monitor
        self._metrics = metrics
        self._config = config
        self._replicas = SortedDict()  # type: SortedDict[int, Replica]
        self._messages_to_replicas = dict()  # type: Dict[int, deque]
        self.register_monitor_handler()

    def add_replica(self, instance_id) -> None:
        """
        Create a replica for ``instance_id``, store it and reset the monitor.

        Instance 0 is created as the master, any other id as a backup. The
        new replica starts at the node's view number (0 when the node has
        none).
        """
        is_master = instance_id == 0
        description = "master" if is_master else "backup"
        bls_bft = self._create_bls_bft_replica(is_master)
        replica = self._new_replica(instance_id, is_master, bls_bft)
        replica.set_view_no(
            self._node.viewNo if self._node.viewNo is not None else 0)
        self._replicas[instance_id] = replica
        self._messages_to_replicas[instance_id] = deque()
        self._monitor.addInstance(instance_id)

        logger.display("{} added replica {} to instance {} ({})".format(
            self._node.name, replica, instance_id, description),
                       extra={"tags": ["node-replica"]})

        logger.info('reset monitor due to replica addition')
        self._monitor.reset()

    def remove_replica(self, inst_id: int):
        """Clean up and forget the replica of ``inst_id``; unknown ids are ignored."""
        if inst_id not in self._replicas:
            return
        replica = self._replicas.pop(inst_id)
        replica.cleanup()

        self._messages_to_replicas.pop(inst_id, None)
        self._monitor.removeInstance(inst_id)
        logger.display("{} removed replica {} from instance {}".format(
            self._node.name, replica, replica.instId),
                       extra={"tags": ["node-replica"]})

    def send_to_internal_bus(self, msg, inst_id: int = None):
        """
        Send ``msg`` on the internal bus of one replica, or of all replicas.

        :param inst_id: target instance; None broadcasts to every replica.
            An unknown instance is only logged.
        """
        if inst_id is None:
            for replica in self._replicas.values():
                replica.internal_bus.send(msg)
        elif inst_id in self._replicas:
            self._replicas[inst_id].internal_bus.send(msg)
        else:
            logger.info("Cannot send msg ({}) to the replica {} "
                        "because it does not exist.".format(msg, inst_id))

    def subscribe_to_internal_bus(self,
                                  message_type: Type,
                                  handler: Callable,
                                  inst_id: int = None):
        """
        Subscribe ``handler`` for ``message_type`` on replica internal buses.

        :param inst_id: target instance; None subscribes on every replica.
            An unknown instance is only logged.
        """
        if inst_id is None:
            for replica in self._replicas.values():
                replica.internal_bus.subscribe(message_type, handler)
        elif inst_id in self._replicas:
            self._replicas[inst_id].internal_bus.subscribe(
                message_type, handler)
        else:
            logger.info("Cannot subscribe for {} for the replica {} "
                        "because it does not exist.".format(
                            message_type, inst_id))

    # TODO unit test
    @property
    def some_replica_is_primary(self) -> bool:
        """True when at least one stored replica reports ``isPrimary``."""
        return any(r.isPrimary for r in self._replicas.values())

    @property
    def master_replica_is_primary(self):
        """``isPrimary`` of the master replica; None when no replica is stored."""
        if self.num_replicas > 0:
            return self._master_replica.isPrimary

    @property
    def _master_replica(self):
        """The replica stored under ``MASTER_REPLICA_INDEX`` (KeyError if absent)."""
        return self._replicas[MASTER_REPLICA_INDEX]

    def service_inboxes(self, limit: int = None):
        """Call ``serviceQueues(limit)`` on every replica and return the summed results."""
        number_of_processed_messages = \
            sum(replica.serviceQueues(limit) for replica in self._replicas.values())
        return number_of_processed_messages

    def pass_message(self, message, instance_id=None):
        """
        Append ``message`` to the inbox of one replica, or of all replicas.

        :param instance_id: target instance; None delivers to every replica.
            A message for an unknown instance is silently dropped.
        """
        if instance_id is not None:
            if instance_id not in self._replicas:
                return
            self._replicas[instance_id].inBox.append(message)
        else:
            for replica in self._replicas.values():
                replica.inBox.append(message)

    def get_output(self, limit: int = None) -> Generator:
        """
        Yield messages popped from the left of the replicas' outboxes.

        :param limit: approximate total number of messages to take; it is
            split evenly between the replicas (at least one message per
            replica). None drains every outbox completely.
        """
        if limit is None:
            per_replica = None
        elif not self._replicas:
            # Nothing to drain, and the division below would fail on zero
            # replicas.
            return
        else:
            per_replica = round(limit / self.num_replicas)
            if per_replica == 0:
                per_replica = 1
                # Logged after the adjustment so the message reports the
                # limit that is actually applied.
                logger.debug("{} forcibly setting replica "
                             "message limit to {}".format(
                                 self._node.name, per_replica))
        # Iterate over a snapshot: consumers may add or remove replicas
        # between two yields.
        for replica in list(self._replicas.values()):
            num = 0
            while replica.outBox:
                yield replica.outBox.popleft()
                num += 1
                if per_replica and num >= per_replica:
                    break

    def take_ordereds_out_of_turn(self) -> Generator:
        """
        Takes all Ordered messages from outbox out of turn

        Yields one ``(instId, removed)`` pair per replica, where ``removed``
        is whatever the replica's ``_remove_ordered_from_queue`` returns.
        """
        for replica in self._replicas.values():
            yield replica.instId, replica._remove_ordered_from_queue()

    def _new_replica(self, instance_id: int, is_master: bool,
                     bls_bft: BlsBft) -> Replica:
        """
        Create a new replica with the specified parameters.
        """
        return self._replica_class(self._node, instance_id, self._config,
                                   is_master, bls_bft, self._metrics)

    def _create_bls_bft_replica(self, is_master):
        """Build the BLS BFT replica object through the default factory for this node."""
        bls_factory = create_default_bls_bft_factory(self._node)
        bls_bft_replica = bls_factory.create_bls_bft_replica(is_master)
        return bls_bft_replica

    @property
    def num_replicas(self):
        """Number of stored replicas."""
        return len(self._replicas)

    @property
    def sum_inbox_len(self):
        """Total number of messages waiting in all replica inboxes."""
        return sum(len(replica.inBox) for replica in self._replicas.values())

    @property
    def all_instances_have_primary(self) -> bool:
        """True when no stored replica has ``primaryName`` set to None."""
        return all(replica.primaryName is not None
                   for replica in self._replicas.values())

    @property
    def primary_name_by_inst_id(self) -> dict:
        """Mapping: instance id -> node name derived from the replica's primary name."""
        return {
            r.instId: replica_name_to_node_name(r.primaryName)
            for r in self._replicas.values()
        }

    @property
    def inst_id_by_primary_name(self) -> dict:
        """Mapping: primary's node name -> instance id, for replicas that have a primary."""
        return {
            replica_name_to_node_name(r.primaryName): r.instId
            for r in self._replicas.values() if r.primaryName
        }

    def register_new_ledger(self, ledger_id):
        """Register ``ledger_id`` on every stored replica."""
        for replica in self._replicas.values():
            replica.register_ledger(ledger_id)

    def register_monitor_handler(self):
        """Hook the unordered-request logging handler into the monitor."""
        # attention: handlers will work over unordered request only once
        self._monitor.unordered_requests_handlers.append(
            self.unordered_request_handler_logging)

    def unordered_request_handler_logging(self, unordereds):
        """
        Log diagnostics for requests that were not ordered in time.

        :param unordereds: iterable of ``(reqId, duration)`` pairs. For each
            request the master replica's pre-prepares are searched; when one
            contains the request, the Prepares/Commits stored for the same
            3PC key and the saved request content are logged.
        """
        replica = self._master_replica
        for unordered in unordereds:
            reqId, duration = unordered
            ordering = replica._ordering_service

            # get ppSeqNo and viewNo
            preprepares = ordering.sent_preprepares \
                if replica.isPrimary else ordering.prePrepares
            ppSeqNo = None
            viewNo = None
            for key in preprepares:
                if any(pre_pre_req == reqId
                       for pre_pre_req in preprepares[key].reqIdr):
                    ppSeqNo = preprepares[key].ppSeqNo
                    viewNo = preprepares[key].viewNo
                    break
            if ppSeqNo is None or viewNo is None:
                logger.warning(
                    'Unordered request with reqId: {} was not found in prePrepares. '
                    'Prepares count: {}, Commits count: {}'.format(
                        reqId, len(ordering.prepares),
                        len(ordering.commits)))
                continue

            # get pre-prepare sender
            prepre_sender = replica.primaryNames.get(viewNo, 'UNKNOWN')

            # get prepares info
            prepares = ordering.prepares[(viewNo, ppSeqNo)][0] \
                if (viewNo, ppSeqNo) in ordering.prepares else []
            n_prepares = len(prepares)
            str_prepares = 'noone'
            if n_prepares:
                str_prepares = ', '.join(prepares)

            # get commits info
            commits = ordering.commits[(viewNo, ppSeqNo)][0] \
                if (viewNo, ppSeqNo) in ordering.commits else []
            n_commits = len(commits)
            str_commits = 'noone'
            if n_commits:
                str_commits = ', '.join(commits)

            # get txn content
            content = replica.requests[reqId].request.as_dict \
                if reqId in replica.requests else 'no content saved'

            logger.warning(
                'Consensus for digest {} was not achieved within {} seconds. '
                'Primary node is {}. '
                'Received Pre-Prepare from {}. '
                'Received {} valid Prepares from {}. '
                'Received {} valid Commits from {}. '
                'Transaction contents: {}. '.format(
                    reqId, duration,
                    replica_name_to_node_name(replica.primaryName),
                    prepre_sender, n_prepares, str_prepares, n_commits,
                    str_commits, content))

    def keys(self):
        """Instance ids of the stored replicas."""
        return self._replicas.keys()

    def values(self):
        """The stored replicas."""
        return self._replicas.values()

    def items(self):
        """``(instance id, replica)`` pairs of the stored replicas."""
        return self._replicas.items()

    def __getitem__(self, item):
        """Return the replica of instance ``item``; a non-int key raises PlenumTypeError."""
        if not isinstance(item, int):
            raise PlenumTypeError('item', item, int)
        return self._replicas[item]

    def __len__(self):
        return self.num_replicas

    def __iter__(self):
        # iterates over instance ids, like a mapping
        return self._replicas.__iter__()
コード例 #4
0
ファイル: replicas.py プロジェクト: michaeldboyd/indy-plenum
class Replicas:
    """
    Collection of the protocol-instance replicas hosted by one node.

    Replicas are stored in a ``SortedDict`` keyed by instance id. The class
    offers a mapping-like read interface over that storage (``keys``,
    ``values``, ``items``, indexing, ``len``; iteration yields instance ids).
    """
    _replica_class = Replica

    def __init__(self, node, monitor: Monitor, config=None, metrics: MetricsCollector = NullMetricsCollector()):
        """
        :param node: owning node; handed to every replica that is created
        :param monitor: told about added/removed instances; the unordered
            request logging handler is registered on it
        :param config: passed through to every created replica
        :param metrics: passed through to every created replica
        """
        # passing full node because Replica requires it
        self._node = node
        self._monitor = monitor
        self._metrics = metrics
        self._config = config
        self._replicas = SortedDict()  # type: SortedDict[int, Replica]
        self._messages_to_replicas = dict()  # type: Dict[int, deque]
        self.register_monitor_handler()

    def add_replica(self, instance_id) -> None:
        """
        Create a replica for ``instance_id`` and store it.

        Instance 0 is created as the master, any other id as a backup.
        """
        is_master = instance_id == 0
        description = "master" if is_master else "backup"
        bls_bft = self._create_bls_bft_replica(is_master)
        replica = self._new_replica(instance_id, is_master, bls_bft)
        self._replicas[instance_id] = replica
        self._messages_to_replicas[instance_id] = deque()
        self._monitor.addInstance(instance_id)

        logger.display("{} added replica {} to instance {} ({})"
                       .format(self._node.name,
                               replica,
                               instance_id,
                               description),
                       extra={"tags": ["node-replica"]})

    def remove_replica(self, inst_id: int):
        """
        Forget the replica of ``inst_id`` and free the requests it holds.

        Unknown ids are ignored. Request keys are gathered from the
        replica's inbox, request queues and (sent) pre-prepares; each one
        still present in ``replica.requests`` is freed there.
        """
        if inst_id not in self._replicas:
            return
        replica = self._replicas.pop(inst_id)

        # Aggregate all the currently forwarded requests
        req_keys = {msg.digest for msg in replica.inBox
                    if isinstance(msg, ReqKey)}
        for req_queue in replica.requestQueues.values():
            req_keys.update(req_queue)
        for pp in replica.sentPrePrepares.values():
            req_keys.update(pp.reqIdr)
        for pp in replica.prePrepares.values():
            req_keys.update(pp.reqIdr)

        for req_key in req_keys:
            if req_key in replica.requests:
                replica.requests.free(req_key)

        self._messages_to_replicas.pop(inst_id, None)
        self._monitor.removeInstance(inst_id)
        logger.display("{} removed replica {} from instance {}".
                       format(self._node.name, replica, replica.instId),
                       extra={"tags": ["node-replica"]})

    # TODO unit test
    @property
    def some_replica_is_primary(self) -> bool:
        """True when at least one stored replica reports ``isPrimary``."""
        return any(r.isPrimary for r in self._replicas.values())

    @property
    def master_replica_is_primary(self):
        """``isPrimary`` of the master replica; None when no replica is stored."""
        if self.num_replicas > 0:
            return self._master_replica.isPrimary

    @property
    def _master_replica(self):
        """The replica stored under ``MASTER_REPLICA_INDEX`` (KeyError if absent)."""
        return self._replicas[MASTER_REPLICA_INDEX]

    def service_inboxes(self, limit: int = None):
        """Call ``serviceQueues(limit)`` on every replica and return the summed results."""
        number_of_processed_messages = \
            sum(replica.serviceQueues(limit) for replica in self._replicas.values())
        return number_of_processed_messages

    def pass_message(self, message, instance_id=None):
        """
        Append ``message`` to the inbox of one replica, or of all replicas.

        :param instance_id: target instance; None delivers to every replica.
            A message for an unknown instance is silently dropped.
        """
        if instance_id is not None:
            if instance_id not in self._replicas:
                return
            self._replicas[instance_id].inBox.append(message)
        else:
            for replica in self._replicas.values():
                replica.inBox.append(message)

    def get_output(self, limit: int = None) -> Generator:
        """
        Yield messages popped from the left of the replicas' outboxes.

        :param limit: approximate total number of messages to take; it is
            split evenly between the replicas (at least one message per
            replica). None drains every outbox completely.
        """
        if limit is None:
            per_replica = None
        elif not self._replicas:
            # Nothing to drain, and the division below would fail on zero
            # replicas.
            return
        else:
            per_replica = round(limit / self.num_replicas)
            if per_replica == 0:
                per_replica = 1
                # Logged after the adjustment so the message reports the
                # limit that is actually applied.
                logger.debug("{} forcibly setting replica "
                             "message limit to {}"
                             .format(self._node.name,
                                     per_replica))
        # Iterate over a snapshot: consumers may add or remove replicas
        # between two yields.
        for replica in list(self._replicas.values()):
            num = 0
            while replica.outBox:
                yield replica.outBox.popleft()
                num += 1
                if per_replica and num >= per_replica:
                    break

    def take_ordereds_out_of_turn(self) -> Generator:
        """
        Takes all Ordered messages from outbox out of turn

        Yields one ``(instId, removed)`` pair per replica, where ``removed``
        is whatever the replica's ``_remove_ordered_from_queue`` returns.
        """
        for replica in self._replicas.values():
            yield replica.instId, replica._remove_ordered_from_queue()

    def _new_replica(self, instance_id: int, is_master: bool, bls_bft: BlsBft) -> Replica:
        """
        Create a new replica with the specified parameters.
        """
        return self._replica_class(self._node, instance_id, self._config, is_master, bls_bft, self._metrics)

    def _create_bls_bft_replica(self, is_master):
        """Build the BLS BFT replica object through the default factory for this node."""
        bls_factory = create_default_bls_bft_factory(self._node)
        bls_bft_replica = bls_factory.create_bls_bft_replica(is_master)
        return bls_bft_replica

    @property
    def num_replicas(self):
        """Number of stored replicas."""
        return len(self._replicas)

    @property
    def sum_inbox_len(self):
        """Total number of messages waiting in all replica inboxes."""
        return sum(len(replica.inBox) for replica in self._replicas.values())

    @property
    def all_instances_have_primary(self) -> bool:
        """True when no stored replica has ``primaryName`` set to None."""
        return all(replica.primaryName is not None
                   for replica in self._replicas.values())

    @property
    def primary_name_by_inst_id(self) -> dict:
        """Mapping: instance id -> primary name up to the first ':' (None without a primary)."""
        return {r.instId: r.primaryName.split(":", maxsplit=1)[0] if r.primaryName else None
                for r in self._replicas.values()}

    @property
    def inst_id_by_primary_name(self) -> dict:
        """Mapping: primary name up to the first ':' -> instance id, for replicas with a primary."""
        return {r.primaryName.split(":", maxsplit=1)[0]: r.instId
                for r in self._replicas.values() if r.primaryName}

    def register_new_ledger(self, ledger_id):
        """Register ``ledger_id`` on every stored replica."""
        for replica in self._replicas.values():
            replica.register_ledger(ledger_id)

    def register_monitor_handler(self):
        """Hook the unordered-request logging handler into the monitor."""
        # attention: handlers will work over unordered request only once
        self._monitor.unordered_requests_handlers.append(
            self.unordered_request_handler_logging)

    def unordered_request_handler_logging(self, unordereds):
        """
        Log diagnostics for requests that were not ordered in time.

        :param unordereds: iterable of ``(reqId, duration)`` pairs. For each
            request the master replica's pre-prepares are searched; when one
            contains the request, the Prepares/Commits stored for the same
            3PC key and the saved request content are logged.
        """
        replica = self._master_replica
        for unordered in unordereds:
            reqId, duration = unordered

            # get ppSeqNo and viewNo
            preprepares = replica.sentPrePrepares if replica.isPrimary else replica.prePrepares
            ppSeqNo = None
            viewNo = None
            for key in preprepares:
                if any(pre_pre_req == reqId for pre_pre_req in preprepares[key].reqIdr):
                    ppSeqNo = preprepares[key].ppSeqNo
                    viewNo = preprepares[key].viewNo
                    break
            if ppSeqNo is None or viewNo is None:
                logger.warning('Unordered request with reqId: {} was not found in prePrepares. '
                               'Prepares count: {}, Commits count: {}'.format(reqId,
                                                                              len(replica.prepares),
                                                                              len(replica.commits)))
                continue

            # get pre-prepare sender; this handler only produces diagnostics,
            # so a view without a recorded primary must not abort it
            prepre_sender = replica.primaryNames.get(viewNo, 'UNKNOWN')

            # get prepares info
            prepares = replica.prepares[(viewNo, ppSeqNo)][0] \
                if (viewNo, ppSeqNo) in replica.prepares else []
            n_prepares = len(prepares)
            str_prepares = 'noone'
            if n_prepares:
                str_prepares = ', '.join(prepares)

            # get commits info
            commits = replica.commits[(viewNo, ppSeqNo)][0] \
                if (viewNo, ppSeqNo) in replica.commits else []
            n_commits = len(commits)
            str_commits = 'noone'
            if n_commits:
                str_commits = ', '.join(commits)

            # get txn content
            content = replica.requests[reqId].finalised.as_dict \
                if reqId in replica.requests else 'no content saved'

            # same guard as primary_name_by_inst_id: the primary may be unset
            primary_node = replica.primaryName.split(':')[0] \
                if replica.primaryName else None

            logger.warning('Consensus for digest {} was not achieved within {} seconds. '
                           'Primary node is {}. '
                           'Received Pre-Prepare from {}. '
                           'Received {} valid Prepares from {}. '
                           'Received {} valid Commits from {}. '
                           'Transaction contents: {}. '
                           .format(reqId, duration, primary_node, prepre_sender,
                                   n_prepares, str_prepares, n_commits, str_commits, content))

    def keys(self):
        """Instance ids of the stored replicas."""
        return self._replicas.keys()

    def values(self):
        """The stored replicas."""
        return self._replicas.values()

    def items(self):
        """``(instance id, replica)`` pairs of the stored replicas."""
        return self._replicas.items()

    def __getitem__(self, item):
        """Return the replica of instance ``item``; a non-int key raises PlenumTypeError."""
        if not isinstance(item, int):
            raise PlenumTypeError('item', item, int)
        return self._replicas[item]

    def __len__(self):
        return self.num_replicas

    def __iter__(self):
        # iterates over instance ids, like a mapping
        return self._replicas.__iter__()
コード例 #5
0
class CheckpointService:
    STASHED_CHECKPOINTS_BEFORE_CATCHUP = 1

    def __init__(
            self,
            data: ConsensusSharedData,
            bus: InternalBus,
            network: ExternalBus,
            stasher: StashingRouter,
            db_manager: DatabaseManager,
            metrics: MetricsCollector = NullMetricsCollector(),
    ):
        """
        Store the collaborators and register the message handlers.

        Checkpoint messages are subscribed through ``stasher``; Ordered,
        BackupSetupLastOrdered and NewViewAccepted through ``bus``.
        """
        self._data = data
        self._bus = bus
        self._network = network
        # Own checkpoint states keyed by (seqNoStart, seqNoEnd); the sort key
        # is the second element of the key, i.e. the end of the range.
        self._checkpoint_state = SortedDict(lambda k: k[1])
        self._stasher = stasher
        self._subscription = Subscription()
        self._validator = CheckpointMsgValidator(self._data)
        self._db_manager = db_manager
        self.metrics = metrics

        # Stashed checkpoints for each view. The key of the outermost
        # dictionary is the view_no, value being a dictionary with key as the
        # range of the checkpoint and its value again being a mapping between
        # senders and their sent checkpoint
        # Dict[view_no, Dict[(seqNoStart, seqNoEnd),  Dict[sender, Checkpoint]]]
        self._stashed_recvd_checkpoints = {}

        self._config = getConfig()
        self._logger = getlogger()

        self._subscription.subscribe(stasher, Checkpoint,
                                     self.process_checkpoint)

        self._subscription.subscribe(bus, Ordered, self.process_ordered)
        self._subscription.subscribe(bus, BackupSetupLastOrdered,
                                     self.process_backup_setup_last_ordered)
        self._subscription.subscribe(bus, NewViewAccepted,
                                     self.process_new_view_accepted)

    def cleanup(self):
        """Drop every subscription held by this service's ``Subscription`` object."""
        self._subscription.unsubscribe_all()

    @property
    def view_no(self):
        """View number taken from the shared consensus data."""
        return self._data.view_no

    @property
    def is_master(self):
        """``is_master`` flag taken from the shared consensus data."""
        return self._data.is_master

    @property
    def last_ordered_3pc(self):
        """``last_ordered_3pc`` taken from the shared consensus data."""
        return self._data.last_ordered_3pc

    @measure_consensus_time(MetricsName.PROCESS_CHECKPOINT_TIME,
                            MetricsName.BACKUP_PROCESS_CHECKPOINT_TIME)
    def process_checkpoint(self, msg: Checkpoint, sender: str) -> (bool, str):
        """
        Validate an incoming Checkpoint and process it if the validator allows.

        :param msg: the received Checkpoint
        :param sender: name of the sender
        :return: ``(None, None)`` for a message addressed to another
            instance, otherwise the ``(result, reason)`` pair produced by the
            validator
        """
        addressed_to_us = not (msg.instId != self._data.inst_id)
        if not addressed_to_us:
            return None, None

        self._logger.info(
            '{} processing checkpoint {} from {}'.format(self, msg, sender))

        result, reason = self._validator.validate(msg)
        should_process = result == PROCESS
        if should_process:
            self._do_process_checkpoint(msg, sender)
        return result, reason

    def _do_process_checkpoint(self, msg: Checkpoint, sender: str) -> bool:
        """
        Handle a validated Checkpoint received from ``sender``.

        The message is stashed when this replica has no own checkpoint state
        with a digest for the same range; otherwise the sender's digest is
        recorded and stability of the checkpoint is re-evaluated.

        :return: False when the message was stashed, True otherwise
        """
        key = (msg.seqNoStart, msg.seqNoEnd)

        has_own_digest = (key in self._checkpoint_state
                          and self._checkpoint_state[key].digest)
        if not has_own_digest:
            self._stash_checkpoint(msg, sender)
            self._remove_stashed_checkpoints(self.last_ordered_3pc)
            self._start_catchup_if_needed()
            return False

        own_state = self._checkpoint_state[key]
        # A digest mismatch is reported only on the master since only
        # master's last ordered 3PC is communicated during view change
        if self.is_master and own_state.digest != msg.digest:
            self._logger.warning("{} received an incorrect digest {} for "
                                 "checkpoint {} from {}".format(
                                     self, msg.digest, key, sender))
        else:
            own_state.receivedDigests[sender] = msg.digest
            self._check_if_checkpoint_stable(key)
        return True

    def process_backup_setup_last_ordered(self, msg: BackupSetupLastOrdered):
        """Update the watermarks from the 3PC state when ``msg`` targets this instance."""
        if msg.inst_id != self._data.inst_id:
            # addressed to another instance
            return
        self.update_watermark_from_3pc()

    def process_ordered(self, ordered: Ordered):
        """
        Add the batch referenced by an Ordered message to the checkpoint state.

        Messages for other instances are ignored. The batch is looked up by
        ``ppSeqNo`` among the preprepared batches, newest first.

        :raises LogicError: when no preprepared batch has that ``ppSeqNo``
        """
        if ordered.instId != self._data.inst_id:
            return

        batch_id = next(
            (candidate for candidate in reversed(self._data.preprepared)
             if candidate.pp_seq_no == ordered.ppSeqNo),
            None)
        if batch_id is None:
            raise LogicError(
                "CheckpointService | Can't process Ordered msg because "
                "ppSeqNo {} not in preprepared".format(ordered.ppSeqNo))

        self._add_to_checkpoint(batch_id.pp_seq_no, batch_id.pp_digest,
                                ordered.ledgerId, batch_id.view_no,
                                ordered.auditTxnRootHash)

    def _start_catchup_if_needed(self):
        """
        Start catching up when too many quorumed checkpoints are stashed.

        Nothing happens unless the computed lag exceeds
        ``STASHED_CHECKPOINTS_BEFORE_CATCHUP``. On the master the low
        watermark is moved to the last quorumed stashed checkpoint end and,
        unless this replica is primary, ``NeedMasterCatchup`` is sent on the
        bus. On a backup ``NeedBackupCatchup`` is sent and the service's own
        state is adjusted through ``caught_up_till_3pc``.
        """
        # presumably the ordered ends (seqNoEnd) of stashed checkpoints that
        # reached a quorum; the helper's body is not visible here
        stashed_checkpoint_ends = self._stashed_checkpoints_with_quorum()
        lag_in_checkpoints = len(stashed_checkpoint_ends)
        if self._checkpoint_state:
            (s, e) = firstKey(self._checkpoint_state)
            # If the first stored own checkpoint has a not aligned lower bound
            # (this means that it was started after a catch-up), is complete
            # and there is a quorumed stashed checkpoint from other replicas
            # with the same end then don't include this stashed checkpoint
            # into the lag
            if s % self._config.CHK_FREQ != 0 \
                    and self._checkpoint_state[(s, e)].seqNo == e \
                    and e in stashed_checkpoint_ends:
                lag_in_checkpoints -= 1
        is_stashed_enough = \
            lag_in_checkpoints > self.STASHED_CHECKPOINTS_BEFORE_CATCHUP
        if not is_stashed_enough:
            return

        if self.is_master:
            self._logger.display(
                '{} has lagged for {} checkpoints so updating watermarks to {}'
                .format(self, lag_in_checkpoints, stashed_checkpoint_ends[-1]))
            self.set_watermarks(low_watermark=stashed_checkpoint_ends[-1])
            if not self._data.is_primary:
                self._logger.display(
                    '{} has lagged for {} checkpoints so the catchup procedure starts'
                    .format(self, lag_in_checkpoints))
                self._bus.send(NeedMasterCatchup())
        else:
            self._logger.info(
                '{} has lagged for {} checkpoints so adjust last_ordered_3pc to {}, '
                'shift watermarks and clean collections'.format(
                    self, lag_in_checkpoints, stashed_checkpoint_ends[-1]))
            # Adjust last_ordered_3pc, shift watermarks, clean operational
            # collections and process stashed messages which now fit between
            # watermarks
            key_3pc = (self.view_no, stashed_checkpoint_ends[-1])
            self._bus.send(
                NeedBackupCatchup(inst_id=self._data.inst_id,
                                  caught_up_till_3pc=key_3pc))
            self.caught_up_till_3pc(key_3pc)

    def gc_before_new_view(self):
        """Reset own checkpoints and remove stashed ones up to ``(view_no, 0)``."""
        self._reset_checkpoints()
        self._remove_stashed_checkpoints(till_3pc_key=(self.view_no, 0))

    def caught_up_till_3pc(self, caught_up_till_3pc):
        """
        Bring the service in line with a finished catch-up.

        Own checkpoints are reset, stashed checkpoints up to
        ``caught_up_till_3pc`` are removed and the watermarks are refreshed
        from the 3PC state.

        :param caught_up_till_3pc: 3PC key passed on as ``till_3pc_key``
        """
        self._reset_checkpoints()
        self._remove_stashed_checkpoints(till_3pc_key=caught_up_till_3pc)
        self.update_watermark_from_3pc()

    def catchup_clear_for_backup(self):
        """Reset own and stashed checkpoints and open the watermarks to ``[0, sys.maxsize]``."""
        self._reset_checkpoints()
        # called without till_3pc_key, unlike the other cleanup paths
        self._remove_stashed_checkpoints()
        self.set_watermarks(low_watermark=0, high_watermark=sys.maxsize)

    def _add_to_checkpoint(self, ppSeqNo, digest, ledger_id, view_no,
                           audit_txn_root_hash):
        """
        Record an ordered batch in the own checkpoint state.

        The digest is appended to the state whose range contains ``ppSeqNo``;
        when no such state exists, a new one is created that spans from
        ``ppSeqNo`` to the next multiple of ``CHK_FREQ``. Once the state's
        ``seqNo`` reaches the end of its range, a checkpoint is sent (only if
        exactly ``CHK_FREQ`` digests were collected) and the stashed
        checkpoints for that range are processed.
        """
        for (s, e) in self._checkpoint_state.keys():
            if s <= ppSeqNo <= e:
                state = self._checkpoint_state[s, e]  # type: CheckpointState
                state.digests.append(digest)
                # updateNamedTuple's result is stored back under the same key
                state = updateNamedTuple(state, seqNo=ppSeqNo)
                self._checkpoint_state[s, e] = state
                break
        else:
            # no existing range contains ppSeqNo: open a new one
            s, e = ppSeqNo, math.ceil(
                ppSeqNo / self._config.CHK_FREQ) * self._config.CHK_FREQ
            self._logger.debug("{} adding new checkpoint state for {}".format(
                self, (s, e)))
            state = CheckpointState(ppSeqNo, [
                digest,
            ], None, {}, False)
            self._checkpoint_state[s, e] = state

        if state.seqNo == e:
            # A range that started unaligned holds fewer than CHK_FREQ
            # digests, so no own checkpoint is sent for it.
            if len(state.digests) == self._config.CHK_FREQ:
                self._do_checkpoint(state, s, e, ledger_id, view_no,
                                    audit_txn_root_hash)
            self._process_stashed_checkpoints((s, e), view_no)

    @measure_consensus_time(MetricsName.SEND_CHECKPOINT_TIME,
                            MetricsName.BACKUP_SEND_CHECKPOINT_TIME)
    def _do_checkpoint(self, state, s, e, ledger_id, view_no,
                       audit_txn_root_hash):
        """
        Finalise the own checkpoint for range ``(s, e)`` and send it.

        ``audit_txn_root_hash`` becomes the checkpoint digest and the
        collected batch digests are cleared. The resulting Checkpoint is sent
        over the network and appended to ``self._data.checkpoints``.
        """
        # TODO CheckpointState/Checkpoint is not a namedtuple anymore
        # 1. check if updateNamedTuple works for the new message type
        # 2. choose another name

        # TODO: This is hack of hacks, should be removed when refactoring is complete
        # A non-master instance without an audit root hash falls back to a
        # fixed placeholder digest.
        if not self.is_master and audit_txn_root_hash is None:
            audit_txn_root_hash = "7RJ5bkAKRy2CCvarRij2jiHC16SVPjHcrpVdNsboiQGv"

        state = updateNamedTuple(state, digest=audit_txn_root_hash, digests=[])
        self._checkpoint_state[s, e] = state
        self._logger.info(
            "{} sending Checkpoint {} view {} checkpointState digest {}. Ledger {} "
            "txn root hash {}. Committed state root hash {} Uncommitted state root hash {}"
            .format(
                self, (s, e), view_no, state.digest, ledger_id,
                self._db_manager.get_txn_root_hash(ledger_id),
                self._db_manager.get_state_root_hash(ledger_id,
                                                     committed=True),
                self._db_manager.get_state_root_hash(ledger_id,
                                                     committed=False)))
        checkpoint = Checkpoint(self._data.inst_id, view_no, s, e,
                                state.digest)
        self._network.send(checkpoint)
        self._data.checkpoints.append(checkpoint)

    def _mark_checkpoint_stable(self, seqNo):
        """
        Mark the own checkpoint whose range ends at ``seqNo`` as stable.

        When no such checkpoint exists only a debug message is logged.
        Otherwise the low watermark is moved to ``seqNo``, the own checkpoint
        states iterated before the stable one are dropped, stashed
        checkpoints up to ``(view_no, seqNo)`` are removed and
        ``CheckpointStabilized`` is sent on the bus.
        """
        previousCheckpoints = []
        for (s, e), state in self._checkpoint_state.items():
            if e == seqNo:
                # TODO CheckpointState/Checkpoint is not a namedtuple anymore
                # 1. check if updateNamedTuple works for the new message type
                # 2. choose another name
                state = updateNamedTuple(state, isStable=True)
                self._checkpoint_state[s, e] = state
                self._set_stable_checkpoint(e)
                break
            else:
                previousCheckpoints.append((s, e))
        else:
            # the loop finished without finding a range ending at seqNo
            self._logger.debug("{} could not find {} in checkpoints".format(
                self, seqNo))
            return
        self.set_watermarks(low_watermark=seqNo)
        for k in previousCheckpoints:
            self._logger.trace("{} removing previous checkpoint {}".format(
                self, k))
            self._checkpoint_state.pop(k)
        self._remove_stashed_checkpoints(till_3pc_key=(self.view_no, seqNo))
        self._bus.send(
            CheckpointStabilized(self._data.inst_id, (self.view_no, seqNo)))
        # (s, e) still refer to the range at which the loop above stopped
        self._logger.info("{} marked stable checkpoint {}".format(
            self, (s, e)))

    def _check_if_checkpoint_stable(self, key: Tuple[int, int]):
        ckState = self._checkpoint_state[key]
        if self._data.quorums.checkpoint.is_reached(
                len(ckState.receivedDigests)):
            self._mark_checkpoint_stable(ckState.seqNo)
            return True
        else:
            self._logger.debug('{} has state.receivedDigests as {}'.format(
                self, ckState.receivedDigests.keys()))
            return False

    def _stash_checkpoint(self, ck: Checkpoint, sender: str):
        """
        Keep a received checkpoint for later processing.

        Stashed checkpoints are grouped by view number, then by
        (seqNoStart, seqNoEnd), then by sender; a newer message from the same
        sender for the same range overwrites the older one.

        :param ck: received Checkpoint message
        :param sender: name of the node the message came from
        """
        self._logger.debug('{} stashing {} from {}'.format(self, ck, sender))
        range_key = (ck.seqNoStart, ck.seqNoEnd)
        per_view = self._stashed_recvd_checkpoints.setdefault(ck.viewNo, {})
        per_view.setdefault(range_key, {})[sender] = ck

    def _stashed_checkpoints_with_quorum(self):
        end_pp_seq_numbers = []
        quorum = self._data.quorums.checkpoint
        for (_, seq_no_end), senders in self._stashed_recvd_checkpoints.get(
                self.view_no, {}).items():
            if quorum.is_reached(len(senders)):
                end_pp_seq_numbers.append(seq_no_end)
        return sorted(end_pp_seq_numbers)

    def _process_stashed_checkpoints(self, key, view_no):
        # Remove all checkpoints from previous views if any
        self._remove_stashed_checkpoints(till_3pc_key=(self.view_no, 0))

        if key not in self._stashed_recvd_checkpoints.get(view_no, {}):
            self._logger.trace("{} have no stashed checkpoints for {}")
            return

        # Get a snapshot of all the senders of stashed checkpoints for `key`
        senders = list(self._stashed_recvd_checkpoints[view_no][key].keys())
        total_processed = 0
        consumed = 0

        for sender in senders:
            # Check if the checkpoint from `sender` is still in
            # `stashed_recvd_checkpoints` because it might be removed from there
            # in case own checkpoint was stabilized when we were processing
            # stashed checkpoints from previous senders in this loop
            if view_no in self._stashed_recvd_checkpoints \
                    and key in self._stashed_recvd_checkpoints[view_no] \
                    and sender in self._stashed_recvd_checkpoints[view_no][key]:
                if self.process_checkpoint(
                        self._stashed_recvd_checkpoints[view_no][key].pop(
                            sender), sender):
                    consumed += 1
                # Note that if `process_checkpoint` returned False then the
                # checkpoint from `sender` was re-stashed back to
                # `stashed_recvd_checkpoints`
                total_processed += 1

        # If we have consumed stashed checkpoints for `key` from all the
        # senders then remove entries which have become empty
        if view_no in self._stashed_recvd_checkpoints \
                and key in self._stashed_recvd_checkpoints[view_no] \
                and len(self._stashed_recvd_checkpoints[view_no][key]) == 0:
            del self._stashed_recvd_checkpoints[view_no][key]
            if len(self._stashed_recvd_checkpoints[view_no]) == 0:
                del self._stashed_recvd_checkpoints[view_no]

        restashed = total_processed - consumed
        self._logger.info('{} processed {} stashed checkpoints for {}, '
                          '{} of them were stashed again'.format(
                              self, total_processed, key, restashed))

        return total_processed

    def reset_watermarks_before_new_view(self):
        """
        Move the low watermark back to 0.

        Watermarks left over from a previous view are discarded because, for
        the view change to complete, this node has to reach the same state as
        the other nodes.
        """
        self.set_watermarks(low_watermark=0)

    def should_reset_watermarks_before_new_view(self):
        """
        Tell whether the watermarks have to be reset for the current view.

        :return: False in view 0 (or below) and when something with a
            positive sequence number was already ordered in the current view;
            True otherwise
        """
        if self.view_no <= 0:
            return False
        last_ordered = self.last_ordered_3pc
        ordered_in_this_view = (last_ordered[0] == self.view_no
                                and last_ordered[1] > 0)
        return not ordered_in_this_view

    def set_watermarks(self, low_watermark: int, high_watermark: int = None):
        """
        Update both watermarks in the shared data and re-process everything
        that was stashed because of watermarks.

        :param low_watermark: new low watermark
        :param high_watermark: new high watermark; when omitted it is set to
            the low watermark plus the configured LOG_SIZE
        """
        data = self._data
        data.low_watermark = low_watermark
        if high_watermark is None:
            data.high_watermark = data.low_watermark + self._config.LOG_SIZE
        else:
            data.high_watermark = high_watermark

        self._logger.info('{} set watermarks as {} {}'.format(
            self, data.low_watermark, data.high_watermark))
        self._stasher.process_all_stashed(STASH_WATERMARKS)

    def update_watermark_from_3pc(self):
        """
        Set the low watermark to the sequence number of the last ordered
        3PC key, provided that key belongs to the current view.

        Nothing is changed (only a log line is written) when there is no last
        ordered 3PC key or when it was ordered in another view.
        """
        last_ordered_3pc = self.last_ordered_3pc
        if last_ordered_3pc is None:
            self._logger.info(
                "try to update_watermark_from_3pc but last_ordered_3pc is None"
            )
            return

        if last_ordered_3pc[0] != self.view_no:
            # Previously this case was also reported as "is None".
            self._logger.info(
                "try to update_watermark_from_3pc but last_ordered_3pc {} "
                "is not from the current view {}".format(
                    last_ordered_3pc, self.view_no))
            return

        self._logger.info(
            "update_watermark_from_3pc to {}".format(last_ordered_3pc))
        self.set_watermarks(last_ordered_3pc[1])

    def _remove_stashed_checkpoints(self, till_3pc_key=None):
        """
        Remove stashed received checkpoints up to `till_3pc_key` if provided,
        otherwise remove all stashed received checkpoints
        """
        if till_3pc_key is None:
            self._stashed_recvd_checkpoints.clear()
            self._logger.info(
                '{} removing all stashed checkpoints'.format(self))
            return

        for view_no in list(self._stashed_recvd_checkpoints.keys()):

            if view_no < till_3pc_key[0]:
                self._logger.info(
                    '{} removing stashed checkpoints for view {}'.format(
                        self, view_no))
                del self._stashed_recvd_checkpoints[view_no]

            elif view_no == till_3pc_key[0]:
                for (s, e) in list(
                        self._stashed_recvd_checkpoints[view_no].keys()):
                    if e <= till_3pc_key[1]:
                        self._logger.info(
                            '{} removing stashed checkpoints: '
                            'viewNo={}, seqNoStart={}, seqNoEnd={}'.format(
                                self, view_no, s, e))
                        del self._stashed_recvd_checkpoints[view_no][(s, e)]
                if len(self._stashed_recvd_checkpoints[view_no]) == 0:
                    del self._stashed_recvd_checkpoints[view_no]

    def _reset_checkpoints(self):
        # That function most probably redundant in PBFT approach,
        # because according to paper, checkpoints cleared only when next stabilized.
        # Avoid using it while implement other services.
        self._checkpoint_state.clear()
        self._data.checkpoints.clear()
        # TODO: change to = 1 in ViewChangeService integration.
        self._data.stable_checkpoint = 0

    def _set_stable_checkpoint(self, end_seq_no):
        """
        Record `end_seq_no` as the stable checkpoint and keep only the
        checkpoints that end at or after it.

        :param end_seq_no: end sequence number of the stable checkpoint
        :raises LogicError: if no known checkpoint ends at `end_seq_no`
        """
        shared_data = self._data
        matching = list(
            shared_data.checkpoints.irange_key(end_seq_no, end_seq_no))
        if not matching:
            raise LogicError('Stable checkpoint must be in checkpoints')
        shared_data.stable_checkpoint = end_seq_no

        remaining = [c for c in shared_data.checkpoints
                     if c.seqNoEnd >= end_seq_no]
        shared_data.checkpoints = SortedListWithKey(
            remaining, key=lambda checkpoint: checkpoint.seqNoEnd)

    def __str__(self) -> str:
        return "{} - checkpoint_service".format(self._data.name)

    # TODO: move to OrderingService as a handler for Cleanup messages
    # def _clear_batch_till_seq_no(self, seq_no):
    #     self._data.preprepared = [pp for pp in self._data.preprepared if pp.ppSeqNo >= seq_no]
    #     self._data.prepared = [p for p in self._data.prepared if p.ppSeqNo >= seq_no]

    def discard(self, msg, reason, sender):
        """
        Write a trace-level log entry about a discarded message.

        :param msg: the discarded message
        :param reason: why it was discarded
        :param sender: who sent it
        """
        log_line = ("{} discard message {} from {} "
                    "with the reason: {}".format(self, msg, sender, reason))
        self._logger.trace(log_line)

    def process_new_view_accepted(self, msg: NewViewAccepted):
        """
        Apply the checkpoint carried by an accepted NewView and announce it.

        The checkpoint is added to the shared data if it is not there yet, it
        becomes the stable checkpoint and the low watermark, and a
        NewViewCheckpointsApplied message is sent on the internal bus.

        :param msg: accepted NewView data
        :return: (PROCESS, None)
        """
        # 1. update shared data
        checkpoint = msg.checkpoint
        known_checkpoints = self._data.checkpoints
        if checkpoint not in known_checkpoints:
            known_checkpoints.append(checkpoint)
        self._set_stable_checkpoint(checkpoint.seqNoEnd)
        self.set_watermarks(low_watermark=checkpoint.seqNoEnd)

        # 2. send NewViewCheckpointsApplied
        applied = NewViewCheckpointsApplied(view_no=msg.view_no,
                                            view_changes=msg.view_changes,
                                            checkpoint=msg.checkpoint,
                                            batches=msg.batches)
        self._bus.send(applied)
        return PROCESS, None
コード例 #6
0
class NodeRegHandler(BatchRequestHandler, WriteRequestHandler):
    def __init__(self, database_manager: DatabaseManager):
        """
        Initialise both handler bases for the pool ledger and start with
        empty node registries.

        :param database_manager: passed on to both base-class initialisers
        """
        BatchRequestHandler.__init__(self, database_manager, POOL_LEDGER_ID)
        WriteRequestHandler.__init__(self, database_manager, NODE,
                                     POOL_LEDGER_ID)

        # Node names as of the last applied / last committed batch.
        self.uncommitted_node_reg = []
        self.committed_node_reg = []

        # view_no -> committed node reg at the beginning of that view, i.e.
        # the committed node reg BEFORE the first txn in the view is applied
        # (that is according to the last txn in the last view)
        self.committed_node_reg_at_beginning_of_view = SortedDict()

        # view_no -> uncommitted node reg at the beginning of that view, i.e.
        # the uncommitted node reg BEFORE the first txn in the view is applied
        # (that is according to the last txn in the last view)
        self.uncommitted_node_reg_at_beginning_of_view = SortedDict()

        # One entry per applied but not yet committed batch, oldest first.
        self._uncommitted = deque()  # type: deque[UncommittedNodeReg]
        self._uncommitted_view_no = 0
        self._committed_view_no = 0

        # Stays None until set_internal_bus() is called.
        self.internal_bus = None  # type: InternalBus

    def set_internal_bus(self, internal_bus: InternalBus):
        """
        Attach the internal bus used to send view-change votes.

        :param internal_bus: bus stored as `self.internal_bus`
        """
        self.internal_bus = internal_bus

    @property
    def active_node_reg(self):
        """
        Uncommitted node registry recorded for the highest known view.

        :return: the registry stored under the last key of
            `uncommitted_node_reg_at_beginning_of_view`, or an empty list if
            nothing is recorded yet
        """
        regs_by_view = self.uncommitted_node_reg_at_beginning_of_view
        if regs_by_view:
            _, latest_reg = regs_by_view.peekitem(-1)
            return latest_reg
        return []

    def on_catchup_finished(self):
        """
        Reload the node registries from the ledgers and log the result.

        The uncommitted per-view registries are replaced by a deep copy of
        the committed ones.
        """
        self._load_current_node_reg()
        # we must have node regs for at least last two views
        self._load_last_view_node_reg()

        committed_by_view = self.committed_node_reg_at_beginning_of_view
        self.uncommitted_node_reg_at_beginning_of_view = copy.deepcopy(
            committed_by_view)
        uncommitted_by_view = self.uncommitted_node_reg_at_beginning_of_view

        logger.info("Loaded current node registry from the ledger: {}".format(
            self.uncommitted_node_reg))
        logger.info(
            "Current committed node registry for previous views: {}".format(
                sorted(committed_by_view.items())))
        logger.info(
            "Current uncommitted node registry for previous views: {}".format(
                sorted(uncommitted_by_view.items())))
        logger.info("Current active node registry: {}".format(
            self.active_node_reg))

    def post_batch_applied(self,
                           three_pc_batch: ThreePcBatch,
                           prev_handler_result=None):
        """
        Record the uncommitted node registry after a batch was applied.

        A snapshot of the registry is queued for the batch and written into
        `three_pc_batch.node_reg`. When the batch is the first one of a newer
        view, the registry as of the previous batch is stored as the
        registry at the beginning of that view.

        :param three_pc_batch: the applied batch
        :param prev_handler_result: not used here
        """
        # Observer case:
        if not self.uncommitted_node_reg and three_pc_batch.node_reg:
            self.uncommitted_node_reg = list(three_pc_batch.node_reg)

        # original_view_no takes precedence over view_no when it is set
        view_no = three_pc_batch.original_view_no
        if view_no is None:
            view_no = three_pc_batch.view_no

        # Update active_node_reg to point to node_reg at the end of last view
        if view_no > self._uncommitted_view_no:
            if len(self._uncommitted) > 0:
                reg_at_end_of_last_view = self._uncommitted[-1].uncommitted_node_reg
            else:
                reg_at_end_of_last_view = self.committed_node_reg
            self.uncommitted_node_reg_at_beginning_of_view[view_no] = list(
                reg_at_end_of_last_view)
            self._uncommitted_view_no = view_no

        snapshot = UncommittedNodeReg(list(self.uncommitted_node_reg), view_no)
        self._uncommitted.append(snapshot)

        three_pc_batch.node_reg = list(self.uncommitted_node_reg)

        committed_by_view = self.committed_node_reg_at_beginning_of_view
        uncommitted_by_view = self.uncommitted_node_reg_at_beginning_of_view
        logger.debug("Applied uncommitted node registry: {}".format(
            self.uncommitted_node_reg))
        logger.debug(
            "Current committed node registry for previous views: {}".format(
                sorted(committed_by_view.items())))
        logger.debug(
            "Current uncommitted node registry for previous views: {}".format(
                sorted(uncommitted_by_view.items())))
        logger.debug("Current active node registry: {}".format(
            self.active_node_reg))

    def post_batch_rejected(self, ledger_id, prev_handler_result=None):
        """
        Revert the most recently applied (not yet committed) batch.

        The uncommitted node registry and view number go back to the values
        of the previous uncommitted batch, or to the committed values when no
        uncommitted batch is left. If the reverted batch was the one that
        started a newer view, the registry recorded for the beginning of that
        view is removed as well.

        :param ledger_id: not used here
        :param prev_handler_result: not used here
        """
        reverted = self._uncommitted.pop()
        if len(self._uncommitted) == 0:
            self.uncommitted_node_reg = list(self.committed_node_reg)
            self._uncommitted_view_no = self._committed_view_no
        else:
            last_uncommitted = self._uncommitted[-1]
            # Copy the list: apply_request() changes uncommitted_node_reg in
            # place, which must not alter the snapshot kept in the queue.
            self.uncommitted_node_reg = list(
                last_uncommitted.uncommitted_node_reg)
            self._uncommitted_view_no = last_uncommitted.view_no

        # find the uncommitted node reg at the beginning of view
        if self._uncommitted_view_no < reverted.view_no:
            self.uncommitted_node_reg_at_beginning_of_view.pop(
                reverted.view_no)

        logger.debug("Reverted uncommitted node registry from {} to {}".format(
            reverted.uncommitted_node_reg, self.uncommitted_node_reg))
        logger.debug(
            "Current committed node registry for previous views: {}".format(
                sorted(self.committed_node_reg_at_beginning_of_view.items())))
        logger.debug(
            "Current uncommitted node registry for previous views: {}".format(
                sorted(
                    self.uncommitted_node_reg_at_beginning_of_view.items())))
        logger.debug("Current active node registry: {}".format(
            self.active_node_reg))

    def commit_batch(self,
                     three_pc_batch: ThreePcBatch,
                     prev_handler_result=None):
        """
        Promote the oldest uncommitted node registry to the committed one.

        When the batch is the first committed one of a newer view, the
        committed registry as it was before this batch is recorded for that
        view and stale per-view entries are garbage-collected. A
        VoteForViewChange is sent on the internal bus (if one is set) when
        the number of nodes changed.

        :param three_pc_batch: the committed batch
        :param prev_handler_result: not used here
        """
        # 1. Update node_reg_at_beginning_of_view first (to match the node reg at the end of last view)
        batch_view_no = three_pc_batch.original_view_no
        if batch_view_no is None:
            batch_view_no = three_pc_batch.view_no

        if batch_view_no > self._committed_view_no:
            self.committed_node_reg_at_beginning_of_view[batch_view_no] = list(
                self.committed_node_reg)
            self._committed_view_no = batch_view_no

            for regs_by_view in (
                    self.committed_node_reg_at_beginning_of_view,
                    self.uncommitted_node_reg_at_beginning_of_view):
                self._gc_node_reg_at_beginning_of_view(regs_by_view)

        # 2. update committed node reg
        prev_committed = self.committed_node_reg
        oldest_uncommitted = self._uncommitted.popleft()
        self.committed_node_reg = oldest_uncommitted.uncommitted_node_reg

        # trigger view change if nodes count changed
        # TODO: create a new message to pass Suspicious events and make ViewChangeTriggerService the only place for
        # view change triggering
        if self.internal_bus and len(prev_committed) != len(
                self.committed_node_reg):
            vote = VoteForViewChange(Suspicions.NODE_COUNT_CHANGED,
                                     batch_view_no + 1)
            self.internal_bus.send(vote)

        # Same four lines either way; a changed registry is logged at info
        # level, an unchanged one at debug level.
        if prev_committed != self.committed_node_reg:
            log = logger.info
        else:
            log = logger.debug
        log("Committed node registry: {}".format(self.committed_node_reg))
        log("Current committed node registry for previous views: {}".format(
            sorted(self.committed_node_reg_at_beginning_of_view.items())))
        log("Current uncommitted node registry for previous views: {}".format(
            sorted(self.uncommitted_node_reg_at_beginning_of_view.items())))
        log("Current active node registry: {}".format(self.active_node_reg))

    def _gc_node_reg_at_beginning_of_view(self, node_reg):
        """
        Remove per-view registries that are older than the view preceding
        the committed one.

        Ex.: node_reg has views {0, 3, 5, 7, 11, 13} and the committed view is
        7: views 11 and 13 (uncommitted) are kept, 7 and its predecessor 5 are
        kept, views 0 and 3 are deleted. The predecessor may be more than one
        view behind the committed one.

        :param node_reg: mapping view_no -> node registry, modified in place
        """
        known_views = list(node_reg.keys())
        if self._committed_view_no in node_reg:
            committed_pos = known_views.index(self._committed_view_no)
            first_kept_pos = max(committed_pos - 1, 0)
        else:
            first_kept_pos = 0

        for stale_view in known_views[:first_kept_pos]:
            node_reg.pop(stale_view, None)

    def apply_request(self, request: Request, batch_ts, prev_result):
        """
        Update the uncommitted node registry according to a NODE request.

        A node is appended when the request lists VALIDATOR among its
        services and the node is not registered yet; it is removed when it is
        registered and VALIDATOR is not among the services. Requests of other
        types or without a services field change nothing.

        :param request: the request being applied
        :param batch_ts: not used here
        :param prev_result: not used here
        :return: always (None, None, None)
        """
        operation = request.operation
        if operation.get(TYPE) != NODE:
            return None, None, None

        data = operation[DATA]
        node_name = data[ALIAS]
        services = data.get(SERVICES)
        if services is None:
            return None, None, None

        is_validator = VALIDATOR in services
        is_registered = node_name in self.uncommitted_node_reg

        if is_validator and not is_registered:
            # new node added or old one promoted
            self.uncommitted_node_reg.append(node_name)
            logger.info("Changed uncommitted node registry to: {}".format(
                self.uncommitted_node_reg))
        elif is_registered and not is_validator:
            # existing node demoted
            self.uncommitted_node_reg.remove(node_name)
            logger.info("Changed uncommitted node registry to: {}".format(
                self.uncommitted_node_reg))

        return None, None, None

    def update_state(self, txn, prev_result, request, is_committed=False):
        """Do nothing: this handler has no state update step."""
        return None

    def static_validation(self, request):
        """Do nothing: this handler performs no static validation."""
        return None

    def additional_dynamic_validation(self, request, req_pp_time):
        """Do nothing: this handler performs no dynamic validation."""
        return None

    def gen_state_key(self, txn):
        """Do nothing: this handler generates no state key (returns None)."""
        return None

    def _load_current_node_reg(self):
        """
        Load the current node registry into both the committed and the
        uncommitted registry (as two independent lists).

        The audit ledger is the preferred source; the pool ledger is used
        when the audit ledger yields nothing.
        """
        loaded = self.__load_current_node_reg_from_audit_ledger()
        if loaded is None:
            loaded = self.__load_node_reg_from_pool_ledger()

        self.uncommitted_node_reg = list(loaded)
        self.committed_node_reg = list(loaded)

    def _load_last_view_node_reg(self):
        """
        Rebuild `committed_node_reg_at_beginning_of_view` and the committed /
        uncommitted view numbers from the audit ledger.

        After the call the mapping holds the node registry for the view of
        the last committed audit txn and, when that view is greater than 0,
        one more entry for the view before it (stored under key 0 when the
        audit ledger has no txn from an earlier view). Without an audit
        ledger only view 0 is filled, from the already loaded registry.
        """
        self.committed_node_reg_at_beginning_of_view.clear()

        # 1. check if we have audit ledger at all
        audit_ledger = self.database_manager.get_ledger(AUDIT_LEDGER_ID)
        if not audit_ledger:
            # don't have audit ledger yet, so get already loaded values from the pool ledger
            self.committed_node_reg_at_beginning_of_view[0] = list(
                self.uncommitted_node_reg)
            self._committed_view_no = 0
            self._uncommitted_view_no = 0
            return

        # 2. get the first txn in the current view and last txn in the last view
        first_txn_in_this_view, last_txn_in_prev_view = self.__get_first_txn_in_view_from_audit(
            audit_ledger, audit_ledger.get_last_committed_txn())

        # 3. set view_no
        self._committed_view_no = get_payload_data(
            first_txn_in_this_view)[AUDIT_TXN_VIEW_NO]
        self._uncommitted_view_no = self._committed_view_no

        # 4. Use last txn in last view to get the node reg
        # get from pool ledger if there is no txns for last view in audit
        if last_txn_in_prev_view is None:
            node_reg_this_view = self.__load_node_reg_for_first_audit_txn(
                first_txn_in_this_view)
        else:
            node_reg_this_view = list(
                self.__load_node_reg_from_audit_txn(audit_ledger,
                                                    last_txn_in_prev_view))
        self.committed_node_reg_at_beginning_of_view[
            self._committed_view_no] = node_reg_this_view

        # 5. Check if audit ledger has information about 0 view only:
        # there is no earlier view to load then
        if self._committed_view_no == 0:
            return

        # 6. If audit has no txn from before the current view (and this view >0), then
        # get the last view from the pool ledger
        if last_txn_in_prev_view is None:
            # assume last view=0 if we don't know it
            self.committed_node_reg_at_beginning_of_view[0] = list(
                self.__load_node_reg_for_first_audit_txn(
                    first_txn_in_this_view))
            return

        # 7. Get the first audit txn for the last view
        first_txn_in_last_view, last_txn_in_pre_last_view = self.__get_first_txn_in_view_from_audit(
            audit_ledger, last_txn_in_prev_view)

        # 8. Use last txn in the view before the last one to get the node reg
        # get from pool ledger if there is no txns for view before the last one in audit
        if last_txn_in_pre_last_view is None:
            node_reg_last_view = self.__load_node_reg_for_first_audit_txn(
                first_txn_in_last_view)
        else:
            node_reg_last_view = list(
                self.__load_node_reg_from_audit_txn(audit_ledger,
                                                    last_txn_in_pre_last_view))
        # the last view's number is taken from its own first audit txn
        last_view_no = get_payload_data(
            first_txn_in_last_view)[AUDIT_TXN_VIEW_NO]
        self.committed_node_reg_at_beginning_of_view[
            last_view_no] = node_reg_last_view

    def __load_node_reg_from_pool_ledger(self, to=None):
        """
        Build the node registry by replaying NODE txns of the pool ledger.

        :param to: passed to `getAllTxn` as the upper bound of the replay
        :return: list of node names that list VALIDATOR among their services,
            in the order in which they were added
        """
        registry = []
        for _, txn in self.ledger.getAllTxn(to=to):
            if get_type(txn) != NODE:
                continue

            node_data = get_payload_data(txn)[DATA]
            alias = node_data[ALIAS]
            services = node_data.get(SERVICES)
            if services is None:
                # a txn without services does not change the registry
                continue

            is_validator = VALIDATOR in services
            is_registered = alias in registry
            if is_validator and not is_registered:
                # new node added or old one promoted
                registry.append(alias)
            elif is_registered and not is_validator:
                # existing node demoted
                registry.remove(alias)
        return registry

    # TODO: create a helper class to get data from Audit Ledger
    def __load_current_node_reg_from_audit_ledger(self):
        """
        Read the node registry from the last committed audit txn.

        :return: the registry, or None when there is no audit ledger or the
            txn (or the txn it refers to) has no registry field
        """
        ledger = self.database_manager.get_ledger(AUDIT_LEDGER_ID)
        if not ledger:
            return None

        newest_txn = ledger.get_last_committed_txn()
        node_reg = get_payload_data(newest_txn).get(AUDIT_TXN_NODE_REG)

        if isinstance(node_reg, int):
            # An int is an offset back to the audit txn holding the value.
            source_seq_no = get_seq_no(newest_txn) - node_reg
            source_txn = ledger.getBySeqNo(source_seq_no)
            node_reg = get_payload_data(source_txn).get(AUDIT_TXN_NODE_REG)

        # None here means that no registry was found
        return node_reg

    def __load_node_reg_from_audit_txn(self, audit_ledger, audit_txn):
        """
        Get the node registry recorded by the given audit txn.

        :param audit_ledger: ledger used to resolve an offset to an earlier txn
        :param audit_txn: audit txn to read the registry from
        :return: the recorded registry; when the txn (or the txn it refers
            to) has none, the registry rebuilt from the pool ledger
        """
        node_reg = get_payload_data(audit_txn).get(AUDIT_TXN_NODE_REG)

        if isinstance(node_reg, int):
            # An int is an offset back to the audit txn holding the value.
            source_seq_no = get_seq_no(audit_txn) - node_reg
            source_txn = audit_ledger.getBySeqNo(source_seq_no)
            node_reg = get_payload_data(source_txn).get(AUDIT_TXN_NODE_REG)

        if node_reg is not None:
            return node_reg
        return self.__load_node_reg_for_first_audit_txn(audit_txn)

    def __get_first_txn_in_view_from_audit(self, audit_ledger,
                                           this_view_first_txn):
        '''
        Walk the audit ledger backwards from the given txn until a txn with
        a different view number, or the beginning of the ledger, is reached.

        :param audit_ledger: audit ledger
        :param this_view_first_txn: a txn from the current view
        :return: the first txn in this view and the last txn in the previous view (if any, otherwise None)
        '''
        this_txn_view_no = get_payload_data(this_view_first_txn).get(
            AUDIT_TXN_VIEW_NO)

        prev_view_last_txn = None
        while True:
            # an int primaries field is an offset back to an earlier audit
            # txn; jump there directly
            txn_primaries = get_payload_data(this_view_first_txn).get(
                AUDIT_TXN_PRIMARIES)
            if isinstance(txn_primaries, int):
                seq_no = get_seq_no(this_view_first_txn) - txn_primaries
                this_view_first_txn = audit_ledger.getBySeqNo(seq_no)
            this_txn_seqno = get_seq_no(this_view_first_txn)
            # reached the first audit txn: there is no previous txn
            if this_txn_seqno <= 1:
                break
            prev_view_last_txn = audit_ledger.getBySeqNo(this_txn_seqno - 1)
            prev_txn_view_no = get_payload_data(prev_view_last_txn).get(
                AUDIT_TXN_VIEW_NO)

            # the previous txn belongs to another view: boundary found
            if this_txn_view_no != prev_txn_view_no:
                break

            # still the same view: continue from the previous txn
            this_view_first_txn = prev_view_last_txn
            prev_view_last_txn = None

        return this_view_first_txn, prev_view_last_txn

    def __load_node_reg_for_first_audit_txn(self, first_audit_txn):
        """
        Rebuild the node registry from the pool ledger, replaying it up to
        the pool ledger size recorded in the given audit txn.

        Used when the audit ledger does not hold the full history, i.e. for
        the first txn in the audit ledger.

        :param first_audit_txn: audit txn providing the pool ledger size
        :return: node registry as a list of node names
        """
        ledger_sizes = get_payload_data(first_audit_txn)[AUDIT_TXN_LEDGERS_SIZE]
        pool_ledger_size = ledger_sizes[POOL_LEDGER_ID]
        return self.__load_node_reg_from_pool_ledger(to=pool_ledger_size)