Ejemplo n.º 1
0
def test_hot_start_db(tmpdir):
    """Check that data added through one DBManager is visible after reopening.

    A FakeBlock is added through a DBManager whose LMDBLockStore lives in
    ``tmpdir``. That manager is closed, and a fresh DBManager (with a new
    ChainFactory) is opened on the same directory. The test then checks that
    the transaction blob, the packed block and the chain frontier are all
    returned unchanged by the new manager.
    """
    tmp_val = tmpdir
    block_store = LMDBLockStore(str(tmp_val))
    chain_factory = ChainFactory()
    dbms = DBManager(chain_factory, block_store)
    try:
        test_block = FakeBlock()
        packed_block = test_block.pack()
        dbms.add_block(packed_block, test_block)
        tx_blob = test_block.transaction

        assert dbms.get_tx_blob_by_dot(test_block.com_id, test_block.com_dot) == tx_blob
        assert (
            dbms.get_block_blob_by_dot(test_block.com_id, test_block.com_dot)
            == packed_block
        )
        front = dbms.get_chain(test_block.com_id).frontier
    finally:
        # Release the store even when an assertion above fails.
        dbms.close()

    # "Hot start": a brand-new manager on the same on-disk directory.
    block_store2 = LMDBLockStore(str(tmp_val))
    chain_factory2 = ChainFactory()
    dbms2 = DBManager(chain_factory2, block_store2)
    try:
        assert dbms2.get_tx_blob_by_dot(test_block.com_id, test_block.com_dot) == tx_blob
        assert (
            dbms2.get_block_blob_by_dot(test_block.com_id, test_block.com_dot)
            == packed_block
        )

        # The chain state must match what the first manager reported.
        assert dbms2.get_chain(test_block.com_id).frontier == front
    finally:
        dbms2.close()

    tmp_val.remove()
Ejemplo n.º 2
0
 def test_block_payload(self):
     """Decoding a packed block as a BlockPayload rebuilds an equal BamiBlock."""
     original = FakeBlock()
     raw = original.pack()
     payloads = original.serializer.ez_unpack_serializables([BlockPayload], raw)
     rebuilt = BamiBlock.from_payload(payloads[0])
     assert rebuilt == original
Ejemplo n.º 3
0
 def test_sign(self):
     """
     Check that a block's stored signature verifies against its unsigned
     serialization under the block's own key.
     """
     block = FakeBlock()
     unsigned_bytes = block.pack(signature=False)
     verified = default_eccrypto.is_valid_signature(
         block.key, unsigned_bytes, block.signature
     )
     assert verified
Ejemplo n.º 4
0
async def test_send_receive_raw_block(monkeypatch, mocker, set_vals):
    """A raw block sent from node 0 makes the receiver call has_block(blk.hash)."""
    blk = FakeBlock(transaction=b"test")
    sender = set_vals.nodes[0].overlay
    receiver_peer = set_vals.nodes[1].overlay.my_peer
    sender.send_block(blk.pack(), [receiver_peer])

    # Stub persistence: add_block does nothing and has_block always says "no".
    def _ignore_add_block(_, __, ___, prefix):
        return None

    def _never_has_block(_, __):
        return False

    monkeypatch.setattr(MockDBManager, "add_block", _ignore_add_block)
    monkeypatch.setattr(MockDBManager, "has_block", _never_has_block)
    has_block_spy = mocker.spy(MockDBManager, "has_block")

    await deliver_messages()
    has_block_spy.assert_called_with(ANY, blk.hash)
Ejemplo n.º 5
0
async def test_send_incorrect_block(monkeypatch, mocker, set_vals):
    """A block packed without its signature is unpacked once and fails the invariant check."""
    blk = FakeBlock(transaction=b"test")
    unsigned_payload = blk.pack(signature=False)
    recipients = [set_vals.nodes[1].overlay.my_peer]
    set_vals.nodes[0].overlay.send_block(unsigned_payload, recipients)

    unpack_spy = mocker.spy(BamiBlock, "unpack")
    invariants_spy = mocker.spy(BamiBlock, "block_invariants_valid")
    await deliver_messages()

    unpack_spy.assert_called_once()
    invariants_spy.assert_called_once()
    assert invariants_spy.spy_return is False
Ejemplo n.º 6
0
 def test_pack_unpack(self):
     """Packing a block and unpacking the bytes yields an equal block."""
     original = FakeBlock()
     serialized = original.pack()
     restored = BamiBlock.unpack(serialized, original.serializer)
     assert original == restored
Ejemplo n.º 7
0
class TestIntegrationDBManager:
    """Integration tests for DBManager backed by LMDBLockStore.

    The two autouse fixtures create ``self.dbms`` and ``self.dbms2``, each with
    its own ChainFactory, so reconciliation between two managers can be tested.
    """

    @pytest.fixture(autouse=True)
    def setUp(self, tmpdir) -> None:
        """Create the first DBManager on ``tmpdir``; close it after the test."""
        tmp_val = tmpdir
        self.block_store = LMDBLockStore(str(tmp_val))
        self.chain_factory = ChainFactory()
        self.dbms = DBManager(self.chain_factory, self.block_store)
        yield
        self.dbms.close()

    @pytest.fixture(autouse=True)
    def setUp2(self, tmpdir) -> None:
        """Create the second DBManager; close it and remove ``tmpdir`` after the test.

        NOTE(review): ``tmpdir`` is presumably the same directory that ``setUp``
        receives within one test, so both block stores would point at one path.
        This teardown may also remove that directory while ``self.dbms`` is
        still open. Confirm the sharing is intended.
        """
        tmp_val = tmpdir
        self.block_store2 = LMDBLockStore(str(tmp_val))
        self.chain_factory2 = ChainFactory()
        self.dbms2 = DBManager(self.chain_factory2, self.block_store2)
        yield
        try:
            self.dbms2.close()
            tmp_val.remove()
        except FileNotFoundError:
            # The directory may already be gone; cleanup is best-effort.
            pass

    def test_get_tx_blob(self):
        """An added block is retrievable both as a tx blob and as a packed blob by dot."""
        self.test_block = FakeBlock()
        packed_block = self.test_block.pack()
        self.dbms.add_block(packed_block, self.test_block)
        self.tx_blob = self.test_block.transaction

        assert (
            self.dbms.get_tx_blob_by_dot(
                self.test_block.com_id, self.test_block.com_dot
            )
            == self.tx_blob
        )
        assert (
            self.dbms.get_block_blob_by_dot(
                self.test_block.com_id, self.test_block.com_dot
            )
            == packed_block
        )

    def test_add_notify_block_one_chain(self, create_batches, insert_function):
        """Observers receive all 100 dots of a single batch, in order."""
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                # The first element of each dot (presumably its sequence
                # number) must start at 1 and increase by exactly one.
                expected_seq = self.val_dots[-1][0] + 1 if self.val_dots else 1
                assert dot[0] == expected_seq
                self.val_dots.append(dot)

        blks = create_batches(num_batches=1, num_blocks=100)
        com_id = blks[0][0].com_id
        self.dbms.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0]))
        assert len(self.val_dots) == 100

    def test_add_notify_block_with_conflicts(self, create_batches, insert_function):
        """Interleaved inserts of two 100-block batches notify all 200 dots."""
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                self.val_dots.append(dot)

        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id
        self.dbms.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0][:20]))
        wrap_iterate(insert_function(self.dbms, blks[1][:40]))
        wrap_iterate(insert_function(self.dbms, blks[0][20:60]))
        wrap_iterate(insert_function(self.dbms, blks[1][40:]))
        wrap_iterate(insert_function(self.dbms, blks[0][60:]))

        assert len(self.val_dots) == 200

    def test_blocks_by_frontier_diff(self, create_batches, insert_function):
        """One reconciliation round between divergent managers yields 41 blobs."""
        # init chain
        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id

        wrap_iterate(insert_function(self.dbms, blks[0][:50]))
        wrap_iterate(insert_function(self.dbms2, blks[1][:50]))

        blobs = self.reconcile_round(com_id)
        assert len(blobs) == 41

    def reconcile_round(self, com_id):
        """Return the blobs ``self.dbms`` supplies for ``self.dbms2``'s frontier diff.

        ``self.dbms2``'s chain reconciles against ``self.dbms``'s frontier. The
        resulting diff is passed to ``get_block_blobs_by_frontier_diff`` on
        ``self.dbms``. ``vals_request`` is a fresh set whose contents are
        ignored here (presumably an out-parameter; confirm).
        """
        front = self.dbms.get_chain(com_id).frontier
        front_diff = self.dbms2.get_chain(com_id).reconcile(front)
        vals_request = set()
        blobs = self.dbms.get_block_blobs_by_frontier_diff(
            com_id, front_diff, vals_request
        )
        return blobs

    def _apply_blobs(self, blobs, serializer):
        """Unpack each blob with ``serializer`` and add it to ``self.dbms2``."""
        for blob in blobs:
            self.dbms2.add_block(blob, FakeBlock.unpack(blob, serializer))

    def test_blocks_by_fdiff_with_holes(self, create_batches, insert_function):
        """Repeated reconciliation rounds fill the gaps in ``self.dbms2``'s chain."""
        # init chain
        blks = create_batches(num_batches=2, num_blocks=100)
        com_id = blks[0][0].com_id
        serializer = blks[0][0].serializer
        self.val_dots = []

        def chain_dots_tester(chain_id, dots):
            for dot in dots:
                self.val_dots.append(dot)

        self.dbms2.add_observer(com_id, chain_dots_tester)

        wrap_iterate(insert_function(self.dbms, blks[0][:50]))
        # dbms2 gets blocks 0-19 and 40-59, leaving a hole at 20-39.
        wrap_iterate(insert_function(self.dbms2, blks[1][:20]))
        wrap_iterate(insert_function(self.dbms2, blks[1][40:60]))

        assert len(self.val_dots) == 20
        blobs = self.reconcile_round(com_id)
        assert len(blobs) == 41
        self._apply_blobs(blobs, serializer)

        assert len(self.val_dots) == 20
        blobs2 = self.reconcile_round(com_id)
        assert len(blobs2) == 8
        self._apply_blobs(blobs2, serializer)

        assert len(self.val_dots) == 20
        blobs3 = self.reconcile_round(com_id)
        assert len(blobs3) == 1
        self._apply_blobs(blobs3, serializer)
        assert len(self.val_dots) == 70