Exemple #1
0
 def setUp(self, tmpdir) -> None:
     tmp_val = tmpdir
     self.block_store = LMDBLockStore(str(tmp_val))
     self.chain_factory = ChainFactory()
     self.dbms = DBManager(self.chain_factory, self.block_store)
     yield
     self.dbms.close()
Exemple #2
0
 def setUp2(self, tmpdir) -> None:
     tmp_val = tmpdir
     self.block_store2 = LMDBLockStore(str(tmp_val))
     self.chain_factory2 = ChainFactory()
     self.dbms2 = DBManager(self.chain_factory2, self.block_store2)
     yield
     try:
         self.dbms2.close()
         tmp_val.remove()
     except FileNotFoundError:
         pass
Exemple #3
0
class TestIntegrationDBManager:
    @pytest.fixture(autouse=True)
    def setUp(self, tmpdir) -> None:
        tmp_val = tmpdir
        self.block_store = LMDBLockStore(str(tmp_val))
        self.chain_factory = ChainFactory()
        self.dbms = DBManager(self.chain_factory, self.block_store)
        yield
        self.dbms.close()

    @pytest.fixture(autouse=True)
    def setUp2(self, tmpdir) -> None:
        tmp_val = tmpdir
        self.block_store2 = LMDBLockStore(str(tmp_val))
        self.chain_factory2 = ChainFactory()
        self.dbms2 = DBManager(self.chain_factory2, self.block_store2)
        yield
        try:
            self.dbms2.close()
            tmp_val.remove()
        except FileNotFoundError:
            pass

    def test_get_tx_blob(self):
        self.test_block = FakeBlock()
        packed_block = self.test_block.pack()
        self.dbms.add_block(packed_block, self.test_block)
        self.tx_blob = self.test_block.transaction

        assert (
            self.dbms.get_tx_blob_by_dot(
                self.test_block.com_id, self.test_block.com_dot
            )
            == self.tx_blob
        )
        assert (
            self.dbms.get_block_blob_by_dot(
                self.test_block.com_id, self.test_block.com_dot
            )
            == packed_block
        )

    def test_add_notify_block_one_chain(self, create_batches, insert_function):
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                assert (len(self.val_dots) == 0 and dot[0] == 1) or dot[
                    0
                ] == self.val_dots[-1][0] + 1
                self.val_dots.append(dot)

        blks = create_batches(num_batches=1, num_blocks=100)
        com_id = blks[0][0].com_id
        self.dbms.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0]))
        assert len(self.val_dots) == 100

    def test_add_notify_block_with_conflicts(self, create_batches, insert_function):
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                self.val_dots.append(dot)

        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id
        self.dbms.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0][:20]))
        wrap_iterate(insert_function(self.dbms, blks[1][:40]))
        wrap_iterate(insert_function(self.dbms, blks[0][20:60]))
        wrap_iterate(insert_function(self.dbms, blks[1][40:]))
        wrap_iterate(insert_function(self.dbms, blks[0][60:]))

        assert len(self.val_dots) == 200

    def test_blocks_by_frontier_diff(self, create_batches, insert_function):
        # init chain
        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id

        wrap_iterate(insert_function(self.dbms, blks[0][:50]))
        wrap_iterate(insert_function(self.dbms2, blks[1][:50]))

        front = self.dbms.get_chain(com_id).frontier
        front_diff = self.dbms2.get_chain(com_id).reconcile(front)
        vals_request = set()

        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            com_id, front_diff, vals_request
        )
        assert len(blobs) == 41

    def reconcile_round(self, com_id):
        front = self.dbms.get_chain(com_id).frontier
        front_diff = self.dbms2.get_chain(com_id).reconcile(front)
        vals_request = set()
        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            com_id, front_diff, vals_request
        )
        return blobs

    def test_blocks_by_fdiff_with_holes(self, create_batches, insert_function):
        # init chain
        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                self.val_dots.append(dot)

        self.dbms2.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0][:50]))
        wrap_iterate(insert_function(self.dbms2, blks[1][:20]))
        wrap_iterate(insert_function(self.dbms2, blks[1][40:60]))

        assert len(self.val_dots) == 20
        blobs = self.reconcile_round(com_id)
        assert len(blobs) == 41

        for b in blobs:
            self.dbms2.add_block(b, FakeBlock.unpack(b, blks[0][0].serializer))

        assert len(self.val_dots) == 20
        blobs2 = self.reconcile_round(com_id)
        assert len(blobs2) == 8
        for b in blobs2:
            self.dbms2.add_block(b, FakeBlock.unpack(b, blks[0][0].serializer))

        assert len(self.val_dots) == 20
        blobs2 = self.reconcile_round(com_id)
        assert len(blobs2) == 1
        for b in blobs2:
            self.dbms2.add_block(b, FakeBlock.unpack(b, blks[0][0].serializer))
        assert len(self.val_dots) == 70
Exemple #4
0
def test_hot_start_db(tmpdir):
    tmp_val = tmpdir
    block_store = LMDBLockStore(str(tmp_val))
    chain_factory = ChainFactory()
    dbms = DBManager(chain_factory, block_store)

    test_block = FakeBlock()
    packed_block = test_block.pack()
    dbms.add_block(packed_block, test_block)
    tx_blob = test_block.transaction

    assert dbms.get_tx_blob_by_dot(test_block.com_id, test_block.com_dot) == tx_blob
    assert (
        dbms.get_block_blob_by_dot(test_block.com_id, test_block.com_dot)
        == packed_block
    )
    front = dbms.get_chain(test_block.com_id).frontier
    dbms.close()

    block_store2 = LMDBLockStore(str(tmp_val))
    chain_factory2 = ChainFactory()
    dbms2 = DBManager(chain_factory2, block_store2)

    assert dbms2.get_tx_blob_by_dot(test_block.com_id, test_block.com_dot) == tx_blob
    assert (
        dbms2.get_block_blob_by_dot(test_block.com_id, test_block.com_dot)
        == packed_block
    )

    assert dbms2.get_chain(test_block.com_id).frontier == front

    dbms2.close()
    tmp_val.remove()
Exemple #5
0
 def setUp(self) -> None:
     self.block_store = MockBlockStore()
     self.chain_factory = MockChainFactory()
     self.dbms = DBManager(self.chain_factory, self.block_store)
Exemple #6
0
class TestDBManager:
    @pytest.fixture(autouse=True)
    def setUp(self) -> None:
        self.block_store = MockBlockStore()
        self.chain_factory = MockChainFactory()
        self.dbms = DBManager(self.chain_factory, self.block_store)

    @pytest.fixture
    def std_vals(self):
        self.chain_id = b"chain_id"
        self.block_dot = Dot((3, ShortKey("808080")))
        self.block_dot_encoded = encode_raw(self.block_dot)
        self.dot_id = self.chain_id + self.block_dot_encoded

        self.test_hash = b"test_hash"
        self.tx_blob = b"tx_blob"
        self.block_blob = b"block_blob"

        self.test_block = FakeBlock()
        self.pers = self.test_block.public_key
        self.com_id = self.test_block.com_id

    def test_get_tx_blob(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore,
            "get_hash_by_dot",
            lambda _, dot_bytes: self.test_hash if dot_bytes == self.dot_id else None,
        )
        monkeypatch.setattr(
            MockBlockStore,
            "get_tx_by_hash",
            lambda _, tx_hash: self.tx_blob if tx_hash == self.test_hash else None,
        )

        assert (
            self.dbms.get_tx_blob_by_dot(self.chain_id, self.block_dot) == self.tx_blob
        )

    def test_get_block_blob(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore,
            "get_hash_by_dot",
            lambda _, dot_bytes: self.test_hash if dot_bytes == self.dot_id else None,
        )
        monkeypatch.setattr(
            MockBlockStore,
            "get_block_by_hash",
            lambda _, blob_hash: self.block_blob
            if blob_hash == self.test_hash
            else None,
        )

        assert (
            self.dbms.get_block_blob_by_dot(self.chain_id, self.block_dot)
            == self.block_blob
        )

    def test_last_frontiers(self, monkeypatch, std_vals):
        new_frontier = Frontier(terminal=((3, b"2123"),), holes=(), inconsistencies=())
        self.dbms.store_last_frontier(self.chain_id, "peer_1", new_frontier)
        front = self.dbms.get_last_frontier(self.chain_id, "peer_1")
        assert front == new_frontier

    def test_add_notify_block(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockChain,
            "add_block",
            lambda _, block_links, seq_num, block_hash: ["dot1", "dot2"]
            if block_hash == self.test_block.hash
            else None,
        )

        def chain_dots_tester(chain_id, dots):
            assert chain_id in (self.test_block.public_key, self.test_block.com_id)
            assert dots == ["dot1", "dot2"]

        self.dbms.add_observer(ChainTopic.ALL, chain_dots_tester)
        self.dbms.add_block(self.block_blob, self.test_block)

    def test_blocks_by_frontier_diff(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore, "get_hash_by_dot", lambda _, dot_bytes: bytes(dot_bytes)
        )
        monkeypatch.setattr(
            MockBlockStore, "get_block_by_hash", lambda _, blob_hash: bytes(blob_hash)
        )
        monkeypatch.setattr(
            MockChain, "get_dots_by_seq_num", lambda _, seq_num: ("dot1", "dot2")
        )

        # init chain
        chain_id = self.chain_id
        self.dbms.chains[chain_id] = MockChain()
        frontier_diff = FrontierDiff(Ranges(((1, 2),)), {(1, ShortKey("efef")): {}})
        vals_to_request = set()

        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            chain_id, frontier_diff, vals_to_request
        )
        assert len(vals_to_request) == 0
        assert len(blobs) == 3

    def test_blocks_frontier_with_extra_request(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore, "get_hash_by_dot", lambda _, dot_bytes: self.test_hash
        )
        monkeypatch.setattr(
            MockBlockStore, "get_block_by_hash", lambda _, blob_hash: self.block_blob
        )
        monkeypatch.setattr(
            MockChain, "get_dots_by_seq_num", lambda _, seq_num: ("dot1", "dot2")
        )

        local_vers = {2: {"ef1"}, 7: {"ef1"}}

        monkeypatch.setattr(
            MockChain,
            "get_all_short_hash_by_seq_num",
            lambda _, seq_num: local_vers.get(seq_num),
        )
        monkeypatch.setattr(
            MockChain,
            "get_next_links",
            lambda _, dot: ((dot[0] + 1, ShortKey("efef")),),
        )

        # init chain
        chain_id = self.chain_id
        self.dbms.chains[chain_id] = MockChain()
        frontier_diff = FrontierDiff(
            (), {(10, ShortKey("efef")): {2: ("ef1",), 7: ("ef2",)}}
        )

        set_to_request = set()
        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            chain_id, frontier_diff, set_to_request
        )
        assert len(set_to_request) == 1
        assert len(blobs) == 1

    def test_blocks_by_frontier_diff_no_seq_num(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore, "get_hash_by_dot", lambda _, dot_bytes: self.test_hash
        )
        monkeypatch.setattr(
            MockBlockStore, "get_block_by_hash", lambda _, blob_hash: self.block_blob
        )
        monkeypatch.setattr(MockChain, "get_dots_by_seq_num", lambda _, seq_num: list())

        # init chain
        chain_id = self.chain_id
        self.dbms.chains[chain_id] = MockChain()
        frontier_diff = FrontierDiff(Ranges(((1, 2),)), {})

        set_to_request = set()
        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            chain_id, frontier_diff, set_to_request
        )
        assert len(set_to_request) == 0
        assert len(list(blobs)) == 0

    def test_blocks_by_frontier_diff_no_chain(self, monkeypatch, std_vals):
        monkeypatch.setattr(
            MockBlockStore, "get_hash_by_dot", lambda _, dot_bytes: self.test_hash
        )
        monkeypatch.setattr(
            MockBlockStore, "get_block_by_hash", lambda _, blob_hash: self.block_blob
        )
        monkeypatch.setattr(
            MockChain, "get_dots_by_seq_num", lambda _, seq_num: list("dot1")
        )

        # init chain
        chain_id = self.chain_id
        # self.dbms.chains[chain_id] = MockChain()
        frontier_diff = FrontierDiff(Ranges(((1, 1),)), {})

        set_to_request = set()
        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            chain_id, frontier_diff, set_to_request
        )
        assert len(set_to_request) == 0
        assert len(list(blobs)) == 0
Exemple #7
0
class BamiCommunity(
    Community,
    BlockSyncMixin,
    SubComGossipMixin,
    SubCommunityMixin,
    BaseSubCommunityFactory,
    SubCommunityDiscoveryStrategy,
    metaclass=ABCMeta,
):
    """
    Community for secure backbone.
    """

    master_peer = Peer(
        unhexlify(
            "4c69624e61434c504b3a062780beaeb40e70fca4cfc1b7751d734f361cf8d815db24dbb8a99fc98af4"
            "39fc977d84f71a431f8825ba885a5cf86b2498c6b473f33dd20dbdcffd199048fc"
        )
    )
    version = b"\x02"

    async def flex_runner(
        self,
        delay: Callable[[], float],
        interval: Callable[[], float],
        task: Callable,
        *args: List
    ) -> None:
        await sleep(delay())
        while True:
            await task(*args)
            await sleep(interval())

    def register_flexible_task(
        self,
        name: str,
        task: Callable,
        *args: List,
        delay: Callable = None,
        interval: Callable = None
    ) -> Union[Future, Task]:
        """
        Register a Task/(coroutine)function so it can be canceled at shutdown time or by name.
        """
        if not delay:

            def delay():
                return random.random()

        if not interval:

            def interval():
                return random.random()

        task = task if iscoroutinefunction(task) else coroutine(task)
        return self.register_task(
            name, ensure_future(self.flex_runner(delay, interval, task, *args))
        )

    def __init__(
        self,
        my_peer: Peer,
        endpoint: Any,
        network: Network,
        ipv8: Optional[IPv8] = None,
        max_peers: int = None,
        anonymize: bool = False,
        db: BaseDB = None,
        work_dir: str = None,
        settings: BamiSettings = None,
        **kwargs
    ):
        """

        Args:
            my_peer:
            endpoint:
            network:
            max_peers:
            anonymize:
            db:
        """
        if not settings:
            self._settings = BamiSettings()
        else:
            self._settings = settings

        if not work_dir:
            work_dir = self.settings.work_directory
        if not db:
            self._persistence = DBManager(ChainFactory(), LMDBLockStore(work_dir))
        else:
            self._persistence = db
        if not max_peers:
            max_peers = self.settings.main_max_peers
        self._ipv8 = ipv8
        super(BamiCommunity, self).__init__(
            my_peer, endpoint, network, max_peers, anonymize=anonymize
        )

        self._logger = logging.getLogger(self.__class__.__name__)
        # Create DB Manager

        self.logger.debug(
            "The Plexus community started with Public Key: %s",
            hexlify(self.my_peer.public_key.key_to_bin()),
        )
        self.relayed_broadcasts = set()

        self.shutting_down = False

        # Sub-Communities logic
        self.my_subscriptions = dict()

        self.peer_subscriptions = (
            dict()
        )  # keeps track of which communities each peer is part of
        self.bootstrap_master = None

        self.periodic_sync_lc = {}

        self.incoming_queues = {}
        self.processing_queue_tasks = {}

        self.ordered_notifier = Notifier()
        self.unordered_notifier = Notifier()

        # Setup and add message handlers
        for base in BamiCommunity.__bases__:
            if issubclass(base, MessageStateMachine):
                base.setup_messages(self)

        self.add_message_handler(SubscriptionsPayload, self.received_peer_subs)

    # ----- Discovery start -----
    def start_discovery(
        self,
        target_peers: int = None,
        discovery_algorithm: Union[Type[RandomWalk], Type[EdgeWalk]] = RandomWalk,
        discovery_params: Dict[str, Any] = None,
    ):

        if not self._ipv8:
            raise IPv8UnavailableException("Cannot start discovery at main community")

        discovery = (
            discovery_algorithm(self)
            if not discovery_params
            else discovery_algorithm(self, **discovery_params)
        )
        if not target_peers:
            target_peers = self.settings.main_min_peers
        self._ipv8.add_strategy(self, discovery, target_peers)

    # ----- Update notifiers for new blocks ------------

    def get_block_and_blob_by_dot(
        self, chain_id: bytes, dot: Dot
    ) -> Tuple[bytes, BamiBlock]:
        """Get blob and serialized block and by the chain_id and dot.
        Can raise DatabaseDesynchronizedException if no block found."""
        blk_blob = self.persistence.get_block_blob_by_dot(chain_id, dot)
        if not blk_blob:
            raise DatabaseDesynchronizedException(
                "Block is not found in db: {chain_id}, {dot}".format(
                    chain_id=chain_id, dot=dot
                )
            )
        block = BamiBlock.unpack(blk_blob, self.serializer)
        return blk_blob, block

    def get_block_by_dot(self, chain_id: bytes, dot: Dot) -> BamiBlock:
        """Get block by the chain_id and dot. Can raise DatabaseDesynchronizedException"""
        return self.get_block_and_blob_by_dot(chain_id, dot)[1]

    def block_notify(self, chain_id: bytes, dots: List[Dot]):
        self.logger.info("Processing dots %s on chain: %s", dots, chain_id)
        for dot in dots:
            block = self.get_block_by_dot(chain_id, dot)
            self.ordered_notifier.notify(chain_id, block)

    def subscribe_in_order_block(
        self, topic: Union[bytes, ChainTopic], callback: Callable[[BamiBlock], None]
    ):
        """Subscribe on block updates received in-order. Callable will receive the block."""
        self._persistence.add_unique_observer(topic, self.block_notify)
        self.ordered_notifier.add_observer(topic, callback)

    def subscribe_out_order_block(
        self, topic: Union[bytes, ChainTopic], callback: Callable[[BamiBlock], None]
    ):
        """Subscribe on block updates received in-order. Callable will receive the block."""
        self.unordered_notifier.add_observer(topic, callback)

    def process_block_unordered(self, blk: BamiBlock, peer: Peer) -> None:
        self.unordered_notifier.notify(blk.com_prefix + blk.com_id, blk)
        if peer != self.my_peer:
            frontier = Frontier(Links((blk.com_dot,)), holes=(), inconsistencies=())
            subcom_id = blk.com_prefix + blk.com_id
            processing_queue = self.incoming_frontier_queue(subcom_id)
            if not processing_queue:
                raise UnknownChainException(
                    "Cannot process block received block with unknown chain. {subcom_id}".format(
                        subcom_id=subcom_id
                    )
                )
            processing_queue.put_nowait((peer, frontier, True))

    # ---- Introduction handshakes => Exchange your subscriptions ----------------
    def create_introduction_request(
        self, socket_address: Any, extra_bytes: bytes = b""
    ):
        extra_bytes = encode_raw(self.my_subcoms)
        return super().create_introduction_request(socket_address, extra_bytes)

    def create_introduction_response(
        self,
        lan_socket_address,
        socket_address,
        identifier,
        introduction=None,
        extra_bytes=b"",
        prefix=None,
    ):
        extra_bytes = encode_raw(self.my_subcoms)
        return super().create_introduction_response(
            lan_socket_address,
            socket_address,
            identifier,
            introduction,
            extra_bytes,
            prefix,
        )

    def introduction_response_callback(self, peer, dist, payload):
        subcoms = decode_raw(payload.extra_bytes)
        self.process_peer_subscriptions(peer, subcoms)
        # TODO: add subscription strategy
        if self.settings.track_neighbours_chains:
            self.subscribe_to_subcom(peer.public_key.key_to_bin())

    def introduction_request_callback(self, peer, dist, payload):
        subcoms = decode_raw(payload.extra_bytes)
        self.process_peer_subscriptions(peer, subcoms)
        # TODO: add subscription strategy
        if self.settings.track_neighbours_chains:
            self.subscribe_to_subcom(peer.public_key.key_to_bin())

    # ----- Community routines ------

    async def unload(self):
        self.logger.debug("Unloading the Plexus Community.")
        self.shutting_down = True
        for mid in self.processing_queue_tasks:
            if not self.processing_queue_tasks[mid].done():
                self.processing_queue_tasks[mid].cancel()
        for subcom_id in self.my_subscriptions:
            await self.my_subscriptions[subcom_id].unload()
        await super(BamiCommunity, self).unload()

        # Close the persistence layer
        self.persistence.close()

    @property
    def settings(self) -> BamiSettings:
        return self._settings

    @property
    def persistence(self) -> BaseDB:
        return self._persistence

    @property
    def my_pub_key_bin(self) -> bytes:
        return self.my_peer.public_key.key_to_bin()

    def send_packet(self, peer: Peer, packet: Any, sig: bool = True) -> None:
        self.ez_send(peer, packet, sig=sig)

    @property
    def my_peer_key(self) -> Key:
        return self.my_peer.key

    # ----- SubCommunity routines ------
    def get_subcom_discovery_strategy(
        self, subcom_id: bytes
    ) -> Union[SubCommunityDiscoveryStrategy, Type[SubCommunityDiscoveryStrategy]]:
        return self

    @property
    def subcom_factory(
        self,
    ) -> Union[BaseSubCommunityFactory, Type[BaseSubCommunityFactory]]:
        return self

    @property
    def my_subcoms(self) -> Iterable[bytes]:
        return list(self.my_subscriptions.keys())

    def get_subcom(self, subcom_id: bytes) -> Optional[BaseSubCommunity]:
        return self.my_subscriptions.get(subcom_id)

    def add_subcom(self, sub_com: bytes, subcom_obj: BaseSubCommunity) -> None:
        if not subcom_obj:
            raise SubCommunityEmptyException("Sub-Community object is none", sub_com)
        self.my_subscriptions[sub_com] = subcom_obj

    def discovered_peers_by_subcom(self, subcom_id: bytes) -> Iterable[Peer]:
        return self.peer_subscriptions.get(subcom_id, [])

    def process_peer_subscriptions(self, peer: Peer, subcoms: List[bytes]) -> None:
        for c in subcoms:
            # For each sub-community that is also known to me - introduce peer.
            if c in self.my_subscriptions:
                self.my_subscriptions[c].add_peer(peer)
            # Keep all sub-communities and peer in a map
            if c not in self.peer_subscriptions:
                self.peer_subscriptions[c] = set()
            self.peer_subscriptions[c].add(peer)

    @lazy_wrapper(SubscriptionsPayload)
    def received_peer_subs(self, peer: Peer, payload: SubscriptionsPayload) -> None:
        subcoms = decode_raw(payload.subcoms)
        self.process_peer_subscriptions(peer, subcoms)

    def notify_peers_on_new_subcoms(self) -> None:
        for peer in self.get_peers():
            self.send_packet(
                peer,
                SubscriptionsPayload(self.my_pub_key_bin, encode_raw(self.my_subcoms)),
            )

    # -------- Community block sharing  -------------

    def start_gossip_sync(
        self,
        subcom_id: bytes,
        prefix: bytes = b"",
        delay: Callable[[], float] = None,
        interval: Callable[[], float] = None,
    ) -> None:
        full_com_id = prefix + subcom_id
        self.logger.debug("Starting gossip with frontiers on chain %s", full_com_id)
        self.periodic_sync_lc[full_com_id] = self.register_flexible_task(
            "gossip_sync_" + str(full_com_id),
            self.gossip_sync_task,
            subcom_id,
            prefix,
            delay=delay
            if delay
            else lambda: random.random() * self._settings.gossip_sync_max_delay,
            interval=interval if interval else lambda: self._settings.gossip_interval,
        )
        self.incoming_queues[full_com_id] = Queue()
        self.processing_queue_tasks[full_com_id] = ensure_future(
            self.process_frontier_queue(full_com_id)
        )

    def incoming_frontier_queue(self, subcom_id: bytes) -> Optional[Queue]:
        return self.incoming_queues.get(subcom_id)

    def get_peer_by_key(
        self, peer_key: bytes, subcom_id: bytes = None
    ) -> Optional[Peer]:
        if subcom_id:
            subcom_peers = self.get_subcom(subcom_id).get_known_peers()
            for peer in subcom_peers:
                if peer.public_key.key_to_bin() == peer_key:
                    return peer
        for peer in self.get_peers():
            if peer.public_key.key_to_bin() == peer_key:
                return peer
        return None

    def choose_community_peers(
        self, com_peers: Iterable[Peer], current_seed: Any, commitee_size: int
    ) -> Iterable[Peer]:
        rand = random.Random(current_seed)
        return rand.sample(com_peers, min(commitee_size, len(com_peers)))

    def share_in_community(
        self,
        block: Union[BamiBlock, bytes],
        subcom_id: bytes = None,
        ttl: int = None,
        fanout: int = None,
        seed: Any = None,
    ) -> None:
        """
        Share block in sub-community via push-based gossip.
        Args:
            block: PlexusBlock to share
            subcom_id: identity of the sub-community, if not specified the main community connections will be used.
            ttl: ttl of the gossip, if not specified - default settings will be used
            fanout: of the gossip, if not specified - default settings will be used
            seed: seed for the peers selection, otherwise random value will be used
        """
        if not subcom_id or not self.get_subcom(subcom_id):
            subcom_peers = self.get_peers()
        else:
            subcom_peers = self.get_subcom(subcom_id).get_known_peers()
        if not seed:
            seed = random.random()
        if not fanout:
            fanout = self.settings.push_gossip_fanout
        if not ttl:
            ttl = self.settings.push_gossip_ttl
        if subcom_peers:
            selected_peers = self.choose_community_peers(subcom_peers, seed, fanout)
            self.send_block(block, selected_peers, ttl)

    # ------ Audits for the chain wrp to invariants -----
    @abstractmethod
    def witness_tx_well_formatted(self, witness_tx: Any) -> bool:
        """
        Returns:
            False if bad format
        """
        pass

    @abstractmethod
    def build_witness_blob(self, chain_id: bytes, seq_num: int) -> Optional[bytes]:
        """
        Args:
            chain_id: bytes identifier of the chain
            seq_num: of the chain to audit to
        Returns:
            witness blob (bytes) if possible, None otherwise
        """
        pass

    @abstractmethod
    def apply_witness_tx(self, block: BamiBlock, witness_tx: Any) -> None:
        pass

    def verify_witness_transaction(self, chain_id: bytes, witness_tx: Any) -> None:
        """
        Verify the witness transaction
        Raises:
             InvalidFormatTransaction
        """
        # 1. Witness transaction ill-formatted
        if not self.witness_tx_well_formatted(witness_tx):
            raise InvalidTransactionFormatException(
                "Invalid witness transaction", chain_id, witness_tx
            )

    def witness(self, chain_id: bytes, seq_num: int) -> None:
        """
        Witness the chain up to a sequence number.
        If chain is known and data exists:
         - Will create a witness block, link to latest known blocks and share in the community.
        Otherwise:
         - Do nothing
        Args:
            chain_id: id of the chain
            seq_num: sequence number of the block:
        """
        chain = self.persistence.get_chain(chain_id)
        if chain:
            witness_blob = self.build_witness_blob(chain_id, seq_num)
            if witness_blob:
                blk = self.create_signed_block(
                    block_type=WITNESS_TYPE,
                    transaction=witness_blob,
                    prefix=b"w",
                    com_id=chain_id,
                    use_consistent_links=False,
                )
                self.logger.debug(
                    "Creating witness block on chain %s: %s, com_dot %s, pers_dot %s",
                    shorten(blk.com_id),
                    seq_num,
                    blk.com_dot,
                    blk.pers_dot,
                )
                self.share_in_community(blk, chain_id)

    def process_witness(self, block: BamiBlock) -> None:
        """Process received witness transaction"""
        witness_tx = self.unpack_witness_blob(block.transaction)
        chain_id = block.com_id
        self.verify_witness_transaction(chain_id, witness_tx)
        # Apply to db
        self.apply_witness_tx(block, witness_tx)

    def unpack_witness_blob(self, witness_blob: bytes) -> Any:
        """
        Returns:
            decoded witness transaction
        """
        return decode_raw(witness_blob)

    # ------ Confirm and reject functions --------------
    def confirm(self, block: BamiBlock, extra_data: Dict = None) -> None:
        """Create confirm block linked to block. Link will be in the transaction with block dot.
           Add extra data to the transaction with a 'extra_data' dictionary.
        """
        chain_id = block.com_id if block.com_id != EMPTY_PK else block.public_key
        dot = block.com_dot if block.com_id != EMPTY_PK else block.pers_dot
        confirm_tx = {b"initiator": block.public_key, b"dot": dot}
        if extra_data:
            confirm_tx.update(extra_data)
        block = self.create_signed_block(
            block_type=CONFIRM_TYPE, transaction=encode_raw(confirm_tx), com_id=chain_id
        )
        self.share_in_community(block, chain_id)

    def verify_confirm_tx(self, claimer: bytes, confirm_tx: Dict) -> None:
        # 1. verify claim format
        if not confirm_tx.get(b"initiator") or not confirm_tx.get(b"dot"):
            raise InvalidTransactionFormatException(
                "Invalid claim ", claimer, confirm_tx
            )

    def process_confirm(self, block: BamiBlock) -> None:
        confirm_tx = decode_raw(block.transaction)
        self.verify_confirm_tx(block.public_key, confirm_tx)
        self.apply_confirm_tx(block, confirm_tx)

    @abstractmethod
    def apply_confirm_tx(self, block: BamiBlock, confirm_tx: Dict) -> None:
        pass

    def reject(self, block: BamiBlock, extra_data: Dict = None) -> None:
        # change it to confirm
        # create claim block and share in the community
        chain_id = block.com_id if block.com_id != EMPTY_PK else block.public_key
        dot = block.com_dot if block.com_id != EMPTY_PK else block.pers_dot
        reject_tx = {b"initiator": block.public_key, b"dot": dot}
        if extra_data:
            reject_tx.update(extra_data)
        block = self.create_signed_block(
            block_type=REJECT_TYPE, transaction=encode_raw(reject_tx), com_id=chain_id
        )
        self.share_in_community(block, chain_id)

    def verify_reject_tx(self, rejector: bytes, confirm_tx: Dict) -> None:
        # 1. verify reject format
        if not confirm_tx.get(b"initiator") or not confirm_tx.get(b"dot"):
            raise InvalidTransactionFormatException(
                "Invalid reject ", rejector, confirm_tx
            )

    def process_reject(self, block: BamiBlock) -> None:
        reject_tx = decode_raw(block.transaction)
        self.verify_reject_tx(block.public_key, reject_tx)
        self.apply_reject_tx(block, reject_tx)

    @abstractmethod
    def apply_reject_tx(self, block: BamiBlock, reject_tx: Dict) -> None:
        pass

    @abstractmethod
    def block_response(
        self, block: BamiBlock, wait_time: float = None, wait_blocks: int = None
    ) -> BlockResponse:
        """
        Respond to block BlockResponse: Reject, Confirm, Delay
        Args:
            block: to respond to
            wait_time: time that passed since first block process initiated
            wait_blocks: number of blocks passed since the block
        Returns:
            BlockResponse: Confirm, Reject or Delay
        """
        pass
Exemple #8
0
    def __init__(
        self,
        my_peer: Peer,
        endpoint: Any,
        network: Network,
        ipv8: Optional[IPv8] = None,
        max_peers: int = None,
        anonymize: bool = False,
        db: BaseDB = None,
        work_dir: str = None,
        settings: BamiSettings = None,
        **kwargs
    ):
        """

        Args:
            my_peer:
            endpoint:
            network:
            max_peers:
            anonymize:
            db:
        """
        if not settings:
            self._settings = BamiSettings()
        else:
            self._settings = settings

        if not work_dir:
            work_dir = self.settings.work_directory
        if not db:
            self._persistence = DBManager(ChainFactory(), LMDBLockStore(work_dir))
        else:
            self._persistence = db
        if not max_peers:
            max_peers = self.settings.main_max_peers
        self._ipv8 = ipv8
        super(BamiCommunity, self).__init__(
            my_peer, endpoint, network, max_peers, anonymize=anonymize
        )

        self._logger = logging.getLogger(self.__class__.__name__)
        # Create DB Manager

        self.logger.debug(
            "The Plexus community started with Public Key: %s",
            hexlify(self.my_peer.public_key.key_to_bin()),
        )
        self.relayed_broadcasts = set()

        self.shutting_down = False

        # Sub-Communities logic
        self.my_subscriptions = dict()

        self.peer_subscriptions = (
            dict()
        )  # keeps track of which communities each peer is part of
        self.bootstrap_master = None

        self.periodic_sync_lc = {}

        self.incoming_queues = {}
        self.processing_queue_tasks = {}

        self.ordered_notifier = Notifier()
        self.unordered_notifier = Notifier()

        # Setup and add message handlers
        for base in BamiCommunity.__bases__:
            if issubclass(base, MessageStateMachine):
                base.setup_messages(self)

        self.add_message_handler(SubscriptionsPayload, self.received_peer_subs)