예제 #1
0
def test_stashing_router_can_stash_and_sort_messages():
    """A sorted stasher replays stashed messages ordered by its sort key."""
    received = []

    def by_int_field(msg: SomeMessage):
        return msg.int_field

    def always_stash(msg: SomeMessage):
        received.append(msg)
        return STASH

    internal_bus = InternalBus()
    stashing_router = StashingRouter(10)
    stashing_router.set_sorted_stasher(STASH, key=by_int_field)
    stashing_router.subscribe(SomeMessage, always_stash)
    stashing_router.subscribe_to(internal_bus)

    sent = [create_some_message() for _ in range(10)]
    for item in sent:
        internal_bus.send(item)

    # First delivery happens in send order.
    assert received == sent

    received.clear()
    stashing_router.process_all_stashed()
    # Replay happens in sort-key order.
    expected_order = sorted(sent, key=by_int_field)
    assert received == expected_order
예제 #2
0
def test_process_all_stashed_doesnt_do_anything_when_there_are_no_items_in_stash():
    """Replaying an empty stash must never invoke the handler."""
    mock_handler = Mock(return_value=(PROCESS, ""))

    internal_bus = InternalBus()
    stashing_router = StashingRouter(10, buses=[internal_bus])
    stashing_router.subscribe(SomeMessage, mock_handler)

    def replay_then(check):
        stashing_router.process_all_stashed()
        check()

    # Nothing has arrived yet, so the replay is a no-op.
    replay_then(mock_handler.assert_not_called)

    msg = create_some_message()
    sender = 'hello'
    internal_bus.send(msg, sender)
    mock_handler.assert_called_once_with(msg, sender)

    # The message was PROCESSed rather than stashed, so replaying must not
    # deliver it a second time.
    replay_then(lambda: mock_handler.assert_called_once_with(msg, sender))
예제 #3
0
def test_stashing_router_can_stash_messages():
    """Stashed messages are re-delivered on each replay until the handler accepts them."""
    remaining_stashes = 3
    received = []

    def handler(msg):
        nonlocal remaining_stashes
        received.append(msg)
        if remaining_stashes <= 0:
            return None, None
        remaining_stashes -= 1
        return STASH, "reason"

    internal_bus = InternalBus()
    stashing_router = StashingRouter(10, buses=[internal_bus])
    stashing_router.subscribe(SomeMessage, handler)

    first, second = create_some_message(), create_some_message()
    internal_bus.send(first)
    internal_bus.send(second)
    assert stashing_router.stash_size() == 2
    assert received == [first, second]

    # Each replay re-delivers everything still stashed. The handler keeps
    # stashing until its budget of three stashes is used up, after which the
    # stash stays empty.
    rounds = [
        (1, [first, second, first, second]),
        (0, [first, second, first, second, first]),
        (0, [first, second, first, second, first]),
    ]
    for expected_size, expected_calls in rounds:
        stashing_router.process_all_stashed()
        assert stashing_router.stash_size() == expected_size
        assert received == expected_calls
예제 #4
0
def test_stashing_router_can_stash_messages_with_metadata():
    """Sender metadata passed to bus.send is preserved across stash/replay cycles."""
    remaining_stashes = 3
    received = []

    def handler(msg, frm):
        nonlocal remaining_stashes
        received.append((msg, frm))
        if remaining_stashes > 0:
            remaining_stashes -= 1
            return STASH
        return None

    internal_bus = InternalBus()
    stashing_router = StashingRouter(10)
    stashing_router.subscribe(SomeMessage, handler)
    stashing_router.subscribe_to(internal_bus)

    msg_a, msg_b = create_some_message(), create_some_message()
    from_a, from_b = (msg_a, 'A'), (msg_b, 'B')
    internal_bus.send(*from_a)
    internal_bus.send(*from_b)
    assert stashing_router.stash_size() == 2
    assert received == [from_a, from_b]

    # Replays carry the original sender along with each message.
    rounds = [
        (1, [from_a, from_b, from_a, from_b]),
        (0, [from_a, from_b, from_a, from_b, from_a]),
        (0, [from_a, from_b, from_a, from_b, from_a]),
    ]
    for expected_size, expected_calls in rounds:
        stashing_router.process_all_stashed()
        assert stashing_router.stash_size() == expected_size
        assert received == expected_calls
예제 #5
0
def test_stashing_router_can_stash_messages_with_different_reasons():
    """Messages stashed under different reason codes can be replayed per reason."""
    received = []

    def handler(message: SomeMessage):
        received.append(message)
        offset = 0 if message.int_field % 2 == 0 else 1
        return STASH + offset, "reason"

    bus = InternalBus()
    router = StashingRouter(10, buses=[bus])
    router.subscribe(SomeMessage, handler)

    messages = [create_some_message() for _ in range(10)]
    for message in messages:
        bus.send(message)

    total = len(messages)
    assert router.stash_size() == total
    per_reason_sum = router.stash_size(STASH + 0) + router.stash_size(STASH + 1)
    assert per_reason_sum == router.stash_size()

    # Replaying every reason re-stashes everything. The even-reason queue is
    # drained before the odd-reason one.
    received.clear()
    router.process_all_stashed()
    assert router.stash_size() == total
    assert received == sorted(messages, key=lambda m: m.int_field % 2)

    # Replaying a single reason touches only that reason's messages.
    per_reason_checks = (
        (STASH + 0, lambda m: m.int_field % 2 == 0),
        (STASH + 1, lambda m: m.int_field % 2 != 0),
    )
    for reason, has_parity in per_reason_checks:
        received.clear()
        router.process_all_stashed(reason)
        assert router.stash_size() == total
        assert router.stash_size(reason) == len(received)
        assert all(has_parity(m) for m in received)
        assert all(m in messages for m in received)
예제 #6
0
class ViewChangeService:
    """Runs the view change protocol for a single node.

    ViewChange, ViewChangeAck and NewView messages arrive from the network
    through a StashingRouter. Messages for a future view are stashed and
    replayed once this node enters that view. If this node is the primary it
    builds and broadcasts the NewView. Otherwise it checks a received NewView
    against the ViewChanges it has collected.
    """

    def __init__(self, data: ConsensusSharedData, timer: TimerService, bus: InternalBus, network: ExternalBus):
        self._config = getConfig()
        self._logger = getlogger()

        self._data = data
        self._new_view_builder = NewViewBuilder(self._data)
        self._timer = timer
        self._bus = bus
        self._network = network
        self._router = StashingRouter(self._config.VIEW_CHANGE_SERVICE_STASH_LIMIT)
        self._votes = ViewChangeVotesForView(self._data.quorums)
        self._new_view = None  # type: Optional[NewView]

        self._router.subscribe(ViewChange, self.process_view_change_message)
        self._router.subscribe(ViewChangeAck, self.process_view_change_ack_message)
        self._router.subscribe(NewView, self.process_new_view_message)
        self._router.subscribe_to(network)

        # Batch IDs remembered across consecutive view changes, keyed by pp_seq_no.
        # start_view_change prunes entries at or below the stable checkpoint.
        self._old_prepared = {}  # type: Dict[int, BatchID]
        self._old_preprepared = {}  # type: Dict[int, List[BatchID]]
        self._primaries_selector = RoundRobinPrimariesSelector()

    def __repr__(self):
        return self._data.name

    def start_view_change(self, view_no: Optional[int] = None):
        """Enter view `view_no` (default: current view + 1) and broadcast our ViewChange.

        Statement order matters. Prepared/preprepared batches are snapshotted
        before the shared data is reset, our own vote is recorded after the
        vote tracker is cleared, and stashed messages are replayed last, once
        the new view number is in place.
        """
        if view_no is None:
            view_no = self._data.view_no + 1

        self._clear_old_batches(self._old_prepared)
        self._clear_old_batches(self._old_preprepared)

        # The most recently prepared batch for each sequence number wins.
        for batch_id in self._data.prepared:
            self._old_prepared[batch_id.pp_seq_no] = batch_id
        prepared = sorted([tuple(bid) for bid in self._old_prepared.values()])

        # Keep one entry per distinct digest for each sequence number. A batch
        # re-preprepared with a known digest replaces the older entry.
        for new_bid in self._data.preprepared:
            pretenders = self._old_preprepared.get(new_bid.pp_seq_no, [])
            pretenders = [bid for bid in pretenders
                          if bid.pp_digest != new_bid.pp_digest]
            pretenders.append(new_bid)
            self._old_preprepared[new_bid.pp_seq_no] = pretenders
        preprepared = sorted([tuple(bid) for bids in self._old_preprepared.values() for bid in bids])

        self._data.view_no = view_no
        self._data.waiting_for_new_view = True
        self._data.primary_name = self._primaries_selector.select_primaries(view_no=self._data.view_no,
                                                                            instance_count=self._data.quorums.f + 1,
                                                                            validators=self._data.validators)[0]
        self._data.preprepared.clear()
        self._data.prepared.clear()
        self._votes.clear()
        self._new_view = None

        vc = ViewChange(
            viewNo=self._data.view_no,
            stableCheckpoint=self._data.stable_checkpoint,
            prepared=prepared,
            preprepared=preprepared,
            checkpoints=list(self._data.checkpoints)
        )
        self._network.send(vc)
        self._votes.add_view_change(vc, self._data.name)

        # Messages stashed earlier for this (then future) view can now be processed.
        self._router.process_all_stashed()

    def process_view_change_message(self, msg: ViewChange, frm: str):
        """Record a ViewChange from `frm`.

        The primary tries to emit a NewView. Any other node acks the
        ViewChange to the primary and tries to finish the view change.
        Returns the validation result when it is not PROCESS, so the router
        can stash or discard the message.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._votes.add_view_change(msg, frm)

        if self._data.is_primary:
            self._send_new_view_if_needed()
            return

        vca = ViewChangeAck(
            viewNo=msg.viewNo,
            name=frm,
            digest=view_change_digest(msg)
        )
        self._network.send(vca, self._data.primary_name)

        self._finish_view_change_if_needed()

    def process_view_change_ack_message(self, msg: ViewChangeAck, frm: str):
        """Record a ViewChangeAck (meaningful only on the primary).

        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        if not self._data.is_primary:
            return

        self._votes.add_view_change_ack(msg, frm)
        self._send_new_view_if_needed()

    def process_new_view_message(self, msg: NewView, frm: str):
        """Store a NewView and try to finish the view change with it.

        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._new_view = msg

        self._finish_view_change_if_needed()

    def _validate(self, msg: Union[ViewChange, ViewChangeAck, NewView], frm: str) -> int:
        """Classify a view change message against this node's current view.

        :return: DISCARD for an older view, or for the current view once the
            view change has finished; STASH for a future view; PROCESS otherwise.
        """
        # TODO: Proper validation

        if msg.viewNo < self._data.view_no:
            return DISCARD

        if msg.viewNo == self._data.view_no and not self._data.waiting_for_new_view:
            return DISCARD

        if msg.viewNo > self._data.view_no:
            return STASH

        return PROCESS

    def _send_new_view_if_needed(self):
        """As primary, broadcast NewView once a quorum of confirmed votes exists.

        Does nothing until confirmed votes reach the view change quorum and the
        builder can compute both a checkpoint and the batches.
        """
        confirmed_votes = self._votes.confirmed_votes
        if not self._data.quorums.view_change.is_reached(len(confirmed_votes)):
            return

        view_changes = [self._votes.get_view_change(*v) for v in confirmed_votes]
        cp = self._new_view_builder.calc_checkpoint(view_changes)
        if cp is None:
            return

        batches = self._new_view_builder.calc_batches(cp, view_changes)
        if batches is None:
            return

        nv = NewView(
            viewNo=self._data.view_no,
            viewChanges=confirmed_votes,
            checkpoint=cp,
            batches=batches
        )
        self._network.send(nv)
        self._new_view = nv
        self._finish_view_change(cp, batches)

    def _finish_view_change_if_needed(self):
        """Validate the received NewView and finish the view change if it checks out.

        The checkpoint and batches are recomputed from the ViewChanges the
        NewView references. On a mismatch the primary is treated as malicious
        and another view change is started.
        """
        if self._new_view is None:
            return

        view_changes = []
        for name, vc_digest in self._new_view.viewChanges:
            vc = self._votes.get_view_change(name, vc_digest)
            # We don't have needed ViewChange, so we cannot validate NewView
            if vc is None:
                return
            view_changes.append(vc)

        cp = self._new_view_builder.calc_checkpoint(view_changes)
        if cp is None or cp != self._new_view.checkpoint:
            # New primary is malicious
            self.start_view_change()
            assert False  # TODO: Test debugging purpose
            return

        batches = self._new_view_builder.calc_batches(cp, view_changes)
        if batches != self._new_view.batches:
            # New primary is malicious
            self.start_view_change()
            assert False  # TODO: Test debugging purpose
            return

        self._finish_view_change(cp, batches)

    def _finish_view_change(self, cp: Checkpoint, batches: List[BatchID]):
        """Adopt `cp` as the stable checkpoint and `batches` as preprepared; leave view change mode."""
        # Update checkpoint
        # TODO: change to self._bus.send(FinishViewChange(cp)) in scope of the task INDY-2179
        self._data.stable_checkpoint = cp.seqNoEnd
        self._data.checkpoints = [old_cp for old_cp in self._data.checkpoints if old_cp.seqNoEnd > cp.seqNoEnd]
        self._data.checkpoints.append(cp)

        # Update batches
        # TODO: Actually we'll need to retrieve preprepares by ID from somewhere
        # Store a copy. On the primary, `batches` is the same list object that
        # was put into the broadcast NewView, and start_view_change() later
        # calls self._data.preprepared.clear() in place, which would otherwise
        # also empty that message's batches.
        self._data.preprepared = list(batches)

        # We finished a view change!
        self._data.waiting_for_new_view = False

    def _clear_old_batches(self, batches: Dict[int, Any]):
        """Drop entries whose pp_seq_no is at or below the stable checkpoint (in place)."""
        for pp_seq_no in list(batches.keys()):
            if pp_seq_no <= self._data.stable_checkpoint:
                del batches[pp_seq_no]
예제 #7
0
class ViewChangeService:
    """Runs the view change protocol for a single node.

    ViewChange, ViewChangeAck and NewView messages arrive from the network
    through a StashingRouter; messages for a future view are stashed and
    replayed once this node enters that view. This version computes the
    NewView checkpoint and batches itself (_calc_checkpoint / _calc_batches).
    """

    def __init__(self, data: ConsensusSharedData, timer: TimerService,
                 bus: InternalBus, network: ExternalBus):
        self._config = getConfig()
        self._logger = getlogger()

        self._data = data
        self._timer = timer
        self._bus = bus
        self._network = network
        self._router = StashingRouter(
            self._config.VIEW_CHANGE_SERVICE_STASH_LIMIT)
        self._votes = ViewChangeVotesForView(self._data.quorums)
        self._new_view = None  # type: Optional[NewView]

        self._router.subscribe(ViewChange, self.process_view_change_message)
        self._router.subscribe(ViewChangeAck,
                               self.process_view_change_ack_message)
        self._router.subscribe(NewView, self.process_new_view_message)
        self._router.subscribe_to(network)

        # Batch IDs remembered across consecutive view changes, keyed by ppSeqNo.
        # start_view_change prunes entries at or below the stable checkpoint.
        self._old_prepared = {}  # type: Dict[int, BatchID]
        self._old_preprepared = {}  # type: Dict[int, List[BatchID]]

    def start_view_change(self, view_no: Optional[int] = None):
        """Enter view `view_no` (default: current view + 1) and broadcast our ViewChange.

        Statement order matters. Prepared/preprepared PrePrepares are
        snapshotted as BatchIDs before the shared data is reset, our own vote
        is recorded after the vote tracker is cleared, and stashed messages are
        replayed last.
        """
        if view_no is None:
            view_no = self._data.view_no + 1

        self._clear_old_batches(self._old_prepared)
        self._clear_old_batches(self._old_preprepared)

        # The most recently prepared batch for each sequence number wins.
        for pp in self._data.prepared:
            self._old_prepared[pp.ppSeqNo] = self._batch_id(pp)
        prepared = sorted([tuple(bid) for bid in self._old_prepared.values()])

        # Keep one entry per distinct digest for each sequence number. A batch
        # re-preprepared with a known digest replaces the older entry.
        for pp in self._data.preprepared:
            new_bid = self._batch_id(pp)
            pretenders = self._old_preprepared.get(pp.ppSeqNo, [])
            pretenders = [
                bid for bid in pretenders if bid.pp_digest != new_bid.pp_digest
            ]
            pretenders.append(new_bid)
            self._old_preprepared[pp.ppSeqNo] = pretenders
        preprepared = sorted([
            tuple(bid) for bids in self._old_preprepared.values()
            for bid in bids
        ])

        self._data.view_no = view_no
        self._data.waiting_for_new_view = True
        self._data.primary_name = self._find_primary(self._data.validators,
                                                     self._data.view_no)
        self._data.preprepared.clear()
        self._data.prepared.clear()
        self._votes.clear()
        self._new_view = None

        vc = ViewChange(viewNo=self._data.view_no,
                        stableCheckpoint=self._data.stable_checkpoint,
                        prepared=prepared,
                        preprepared=preprepared,
                        checkpoints=list(self._data.checkpoints))
        self._network.send(vc)
        self._votes.add_view_change(vc, self._data.name)

        # Messages stashed earlier for this (then future) view can now be processed.
        self._router.process_all_stashed()

    def process_view_change_message(self, msg: ViewChange, frm: str):
        """Record a ViewChange from `frm`.

        The primary tries to emit a NewView. Any other node acks the
        ViewChange to the primary and tries to finish the view change.
        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._votes.add_view_change(msg, frm)

        if self._data.is_primary:
            self._send_new_view_if_needed()
            return

        vca = ViewChangeAck(viewNo=msg.viewNo,
                            name=frm,
                            digest=view_change_digest(msg))
        self._network.send(vca, self._data.primary_name)

        self._finish_view_change_if_needed()

    def process_view_change_ack_message(self, msg: ViewChangeAck, frm: str):
        """Record a ViewChangeAck (meaningful only on the primary).

        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        if not self._data.is_primary:
            return

        self._votes.add_view_change_ack(msg, frm)
        self._send_new_view_if_needed()

    def process_new_view_message(self, msg: NewView, frm: str):
        """Store a NewView and try to finish the view change with it.

        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._new_view = msg

        self._finish_view_change_if_needed()

    @staticmethod
    def _find_primary(validators: List[str], view_no: int) -> str:
        """Round-robin primary selection: validator at index view_no mod len(validators)."""
        return validators[view_no % len(validators)]

    def _validate(self, msg: Union[ViewChange, ViewChangeAck, NewView],
                  frm: str) -> int:
        """Classify a view change message against this node's current view.

        :return: DISCARD for an older view, or for the current view once the
            view change has finished; STASH for a future view; PROCESS otherwise.
        """
        # TODO: Proper validation

        if msg.viewNo < self._data.view_no:
            return DISCARD

        if msg.viewNo == self._data.view_no and not self._data.waiting_for_new_view:
            return DISCARD

        if msg.viewNo > self._data.view_no:
            return STASH

        return PROCESS

    def _send_new_view_if_needed(self):
        """As primary, broadcast NewView once a quorum of confirmed votes exists.

        Does nothing until confirmed votes reach the view change quorum and a
        checkpoint can be computed.
        """
        confirmed_votes = self._votes.confirmed_votes
        if not self._data.quorums.view_change.is_reached(len(confirmed_votes)):
            return

        view_changes = [
            self._votes.get_view_change(*v) for v in confirmed_votes
        ]
        cp = self._calc_checkpoint(view_changes)
        if cp is None:
            return

        batches = self._calc_batches(cp, view_changes)

        nv = NewView(viewNo=self._data.view_no,
                     viewChanges=confirmed_votes,
                     checkpoint=cp,
                     batches=batches)
        self._network.send(nv)
        self._new_view = nv
        self._finish_view_change(cp, batches)

    def _finish_view_change_if_needed(self):
        """Validate the received NewView and finish the view change if it checks out.

        The checkpoint and batches are recomputed from the ViewChanges the
        NewView references. On a mismatch the primary is treated as malicious
        and another view change is started.
        """
        if self._new_view is None:
            return

        view_changes = []
        for name, vc_digest in self._new_view.viewChanges:
            vc = self._votes.get_view_change(name, vc_digest)
            # We don't have needed ViewChange, so we cannot validate NewView
            if vc is None:
                return
            view_changes.append(vc)

        cp = self._calc_checkpoint(view_changes)
        if cp is None or cp != self._new_view.checkpoint:
            # New primary is malicious
            self.start_view_change()
            assert False  # TODO: Test debugging purpose
            return

        batches = self._calc_batches(cp, view_changes)
        if batches != self._new_view.batches:
            # New primary is malicious
            self.start_view_change()
            assert False  # TODO: Test debugging purpose
            return

        self._finish_view_change(cp, batches)

    def _finish_view_change(self, cp: Checkpoint, batches: List[BatchID]):
        """Adopt `cp` as the stable checkpoint and `batches` as preprepared; leave view change mode."""
        # Update checkpoint
        self._data.stable_checkpoint = cp.seqNoEnd
        self._data.checkpoints = [
            old_cp for old_cp in self._data.checkpoints
            if old_cp.seqNoEnd > cp.seqNoEnd
        ]
        self._data.checkpoints.append(cp)

        # Update batches
        # TODO: Actually we'll need to retrieve preprepares by ID from somewhere
        # Store a copy. On the primary, `batches` is the same list object that
        # was put into the broadcast NewView, and start_view_change() later
        # calls self._data.preprepared.clear() in place, which would otherwise
        # also empty that message's batches.
        self._data.preprepared = list(batches)

        # We finished a view change!
        self._data.waiting_for_new_view = False

    def _clear_old_batches(self, batches: Dict[int, Any]):
        """Drop entries whose pp_seq_no is at or below the stable checkpoint (in place)."""
        for pp_seq_no in list(batches.keys()):
            if pp_seq_no <= self._data.stable_checkpoint:
                del batches[pp_seq_no]

    @staticmethod
    def _batch_id(batch: PrePrepare):
        """Build a BatchID (view_no, pp_seq_no, pp_digest) from a PrePrepare."""
        return BatchID(batch.viewNo, batch.ppSeqNo, batch.digest)

    def _calc_checkpoint(self, vcs: List[ViewChange]) -> Optional[Checkpoint]:
        """Return the highest checkpoint (by seqNoEnd) that qualifies, or None.

        A checkpoint qualifies when a strong quorum of `vcs` have a
        stableCheckpoint not above its seqNoEnd, and a weak quorum of `vcs`
        list it.
        """
        checkpoints = []
        for cur_vc in vcs:
            for cur_cp in cur_vc.checkpoints:
                # Don't add checkpoint to pretending ones if it is already there
                if cur_cp in checkpoints:
                    continue

                # Don't add checkpoint to pretending ones if too many nodes already stabilized it
                # TODO: Should we take into account view_no as well?
                stable_checkpoint_not_higher = [
                    vc for vc in vcs if cur_cp.seqNoEnd >= vc.stableCheckpoint
                ]
                if not self._data.quorums.strong.is_reached(
                        len(stable_checkpoint_not_higher)):
                    continue

                # Don't add checkpoint to pretending ones if not enough nodes have it
                have_checkpoint = [
                    vc for vc in vcs if cur_cp in vc.checkpoints
                ]
                if not self._data.quorums.weak.is_reached(
                        len(have_checkpoint)):
                    continue

                # All checks passed, this is a valid candidate checkpoint
                checkpoints.append(cur_cp)

        highest_cp = None
        for cp in checkpoints:
            # TODO: Should we take into account view_no as well?
            if highest_cp is None or cp.seqNoEnd > highest_cp.seqNoEnd:
                highest_cp = cp

        return highest_cp

    def _calc_batches(self, cp: Checkpoint,
                      vcs: List[ViewChange]) -> List[BatchID]:
        """Return the sorted batches to carry into the new view.

        A batch from any `prepared` list is taken if _is_batch_prepared
        accepts it. A batch from any `preprepared` list is taken if
        _is_batch_preprepared accepts it.

        NOTE(review): this is an OR of the two checks, with each check applied
        only to its own list. The PBFT new-view rule selects a request from a
        prepared set only when BOTH A1 and A2 hold for it. As written, a
        batch seen only in preprepared lists can be selected, and two digests
        for the same pp_seq_no may both pass. Confirm this is intended.
        """
        # TODO: Optimize this
        batches = set()
        for vc in vcs:
            for _bid in vc.prepared:
                bid = BatchID(*_bid)
                if bid in batches:
                    continue
                if self._is_batch_prepared(bid, cp, vcs):
                    batches.add(bid)

            for _bid in vc.preprepared:
                bid = BatchID(*_bid)
                if bid in batches:
                    continue
                if self._is_batch_preprepared(bid, cp, vcs):
                    batches.add(bid)

        return sorted(batches)

    def _is_batch_prepared(self, bid: BatchID, cp: Checkpoint,
                           vcs: List[ViewChange]) -> bool:
        """Check that a strong quorum of `vcs` witness `bid` as prepared.

        A ViewChange is a witness when its stableCheckpoint is below
        bid.pp_seq_no, and the first prepared entry for that pp_seq_no is
        either from an older view or equal to `bid`.

        NOTE(review): a ViewChange with no prepared entry for this pp_seq_no
        is NOT counted here. PBFT's condition A1 is vacuously true for such a
        message. Changing that alone would be unsafe while _calc_batches does
        not also require A2, so both should be revisited together.
        """
        if not self._is_inside_watermarks(bid, cp):
            return False

        def check(vc: ViewChange):
            if bid.pp_seq_no <= vc.stableCheckpoint:
                return False

            for _some_bid in vc.prepared:
                some_bid = BatchID(*_some_bid)
                if some_bid.pp_seq_no != bid.pp_seq_no:
                    continue
                if some_bid.view_no < bid.view_no:
                    return True
                return some_bid == bid

            return False

        prepared_witnesses = sum(1 for vc in vcs if check(vc))
        return self._data.quorums.strong.is_reached(prepared_witnesses)

    def _is_batch_preprepared(self, bid: BatchID, cp: Checkpoint,
                              vcs: List[ViewChange]) -> bool:
        """Check that a weak quorum of `vcs` preprepared `bid`'s seq no and digest in its view or later."""
        if not self._is_inside_watermarks(bid, cp):
            return False

        def check(vc: ViewChange):
            for _some_bid in vc.preprepared:
                some_bid = BatchID(*_some_bid)
                if some_bid.pp_seq_no != bid.pp_seq_no:
                    continue
                if some_bid.pp_digest != bid.pp_digest:
                    continue
                if some_bid.view_no >= bid.view_no:
                    return True

            return False

        preprepared_witnesses = sum(1 for vc in vcs if check(vc))
        return self._data.quorums.weak.is_reached(preprepared_witnesses)

    def _is_inside_watermarks(self, bid: BatchID, cp: Checkpoint) -> bool:
        """True when cp.seqNoEnd < bid.pp_seq_no <= cp.seqNoEnd + 300 (hard-coded log size)."""
        # TODO: Get log size from ConsensusDataProvider
        return cp.seqNoEnd < bid.pp_seq_no <= cp.seqNoEnd + 300
예제 #8
0
class ViewChangeService:
    """Runs an early, partial version of the view change protocol for one node.

    Prepared/preprepared sets and the NewView checkpoint are not computed yet
    (see TODOs). The primary sends NewView as soon as the vote tracker reports
    a view change quorum.
    """

    def __init__(self, data: ConsensusDataProvider, timer: TimerService, bus: InternalBus, network: ExternalBus):
        self._config = getConfig()
        self._logger = getlogger()

        self._data = data
        self._timer = timer
        self._bus = bus
        self._network = network
        # Messages for a future view are stashed here and replayed in start_view_change.
        self._stasher = StashingRouter(self._config.VIEW_CHANGE_SERVICE_STASH_LIMIT)
        self._votes = ViewChangeVotesForView(self._data.quorums)

        self._stasher.subscribe(ViewChange, self.process_view_change_message)
        self._stasher.subscribe(ViewChangeAck, self.process_view_change_ack_message)
        self._stasher.subscribe(NewView, self.process_new_view_message)
        self._stasher.subscribe_to(network)

    def start_view_change(self, view_no: Optional[int] = None):
        """Enter view `view_no` (default: current view + 1) and broadcast our ViewChange."""
        if view_no is None:
            view_no = self._data.view_no + 1

        # TODO: Calculate
        prepared = []
        preprepared = []

        self._data.view_no = view_no
        self._data.waiting_for_new_view = True
        self._data.primary_name = self._find_primary(self._data.validators, self._data.view_no)
        self._votes.clear()

        vc = ViewChange(
            viewNo=self._data.view_no,
            stableCheckpoint=self._data.stable_checkpoint,
            prepared=prepared,
            preprepared=preprepared,
            # Snapshot the list. Passing the live self._data.checkpoints would
            # let later in-place updates alter a message that was already sent
            # (and whose digest peers acknowledge).
            checkpoints=list(self._data.checkpoints)
        )
        self._network.send(vc)
        # Record our own vote. Nothing else in this class adds it, so without
        # this a new primary could not count its own ViewChange, and acks that
        # peers send for it would have no vote to attach to (assuming the
        # network does not loop broadcasts back to the sender).
        self._votes.add_view_change(vc, self._data.name)

        self._stasher.process_all_stashed()

    def process_view_change_message(self, msg: ViewChange, frm: str):
        """Record a ViewChange from `frm`.

        The primary tries to emit a NewView. Any other node acks the
        ViewChange to the primary. Returns the validation result when it is
        not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._votes.add_view_change(msg, frm)

        if self._data.is_primary:
            self._send_new_view_if_needed()
            return

        vca = ViewChangeAck(
            viewNo=msg.viewNo,
            name=frm,
            digest=view_change_digest(msg)
        )
        self._network.send(vca, self._data.primary_name)

    def process_view_change_ack_message(self, msg: ViewChangeAck, frm: str):
        """Record a ViewChangeAck (meaningful only on the primary).

        Returns the validation result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        if not self._data.is_primary:
            return

        self._votes.add_view_change_ack(msg, frm)
        self._send_new_view_if_needed()

    def process_new_view_message(self, msg: NewView, frm: str):
        """Finish the view change on receipt of a valid-view NewView.

        The NewView contents are not checked yet. Returns the validation
        result when it is not PROCESS.
        """
        result = self._validate(msg, frm)
        if result != PROCESS:
            return result

        self._data.waiting_for_new_view = False

    @staticmethod
    def _find_primary(validators: List[str], view_no: int) -> str:
        """Round-robin primary selection: validator at index view_no mod len(validators)."""
        return validators[view_no % len(validators)]

    def _is_primary(self, view_no: int) -> bool:
        """True if this node would be primary in `view_no`."""
        # TODO: Do we really need this?
        return self._find_primary(self._data.validators, view_no) == self._data.name

    def _validate(self, msg: Union[ViewChange, ViewChangeAck, NewView], frm: str) -> int:
        """Classify a view change message against this node's current view.

        :return: DISCARD for an older view, or for the current view once the
            view change has finished; STASH for a future view; PROCESS otherwise.
        """
        # TODO: Proper validation

        if msg.viewNo < self._data.view_no:
            return DISCARD

        if msg.viewNo == self._data.view_no and not self._data.waiting_for_new_view:
            return DISCARD

        if msg.viewNo > self._data.view_no:
            return STASH

        return PROCESS

    def _send_new_view_if_needed(self):
        """As primary, broadcast a NewView and leave view change mode once a quorum is reached."""
        if not self._votes.has_view_change_quorum:
            return

        nv = NewView(
            viewNo=self._data.view_no,
            viewChanges=self._votes.confirmed_votes,
            checkpoint=None,
            preprepares=[]
        )
        self._network.send(nv)
        self._data.waiting_for_new_view = False