示例#1
0
    def test_requires_genesis_fails_if_joins_network_with_file(self):
        """
        In this case, when there is
         - a genesis_batch_file
         - network id
        the validator should raise an InvalidGenesisStateError, as it is
        joining a network, and not a genesis node.
        """
        self._with_empty_batch_file()
        self._with_network_name('some_block_chain_id')

        block_store = self.make_block_store()
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Mock's first positional parameter is ``spec``, not ``name``; pass
        # the labels by keyword so these are plain named mocks instead of
        # mocks specced against ``str``.
        genesis_ctrl = GenesisController(
            context_manager=Mock(name='context_manager'),
            transaction_executor=Mock(name='txn_executor'),
            completer=Mock(name='completer'),
            block_store=block_store,
            state_view_factory=Mock(name='StateViewFactory'),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        with self.assertRaises(InvalidGenesisStateError):
            genesis_ctrl.requires_genesis()
示例#2
0
    def test_requires_genesis_fails_if_block_exists(self):
        """
        In this case, when there is
         - a genesis_batch_file
         - a chain head id
        the validator should raise an InvalidGenesisStateError, as it
        already has a genesis block and should not attempt to produce
        another.
        """
        self._with_empty_batch_file()

        # Seed the block store with one block so a chain head exists.
        block = self._create_block()
        block_store = self.make_block_store({block.header_signature: block})
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Labels are passed as ``name=``; a positional string would be
        # interpreted by Mock as ``spec``.
        genesis_ctrl = GenesisController(
            context_manager=Mock(name='context_manager'),
            transaction_executor=Mock(name='txn_executor'),
            completer=Mock(name='completer'),
            block_store=block_store,
            state_view_factory=Mock(name='StateViewFactory'),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        with self.assertRaises(InvalidGenesisStateError):
            genesis_ctrl.requires_genesis()
示例#3
0
    def test_does_not_require_genesis_with_no_file_no_network(self):
        """
        In this case, when there is:
         - no genesis.batch file
         - no chain head
         - no network
        the GenesisController should not require genesis.
        """
        block_store = self.make_block_store()
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Labels are passed as ``name=``; a positional string would be
        # interpreted by Mock as ``spec``.
        genesis_ctrl = GenesisController(
            context_manager=Mock(name='context_manager'),
            transaction_executor=Mock(name='txn_executor'),
            completer=Mock(name='completer'),
            block_store=block_store,
            state_view_factory=Mock(name='StateViewFactory'),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        self.assertEqual(False, genesis_ctrl.requires_genesis())
示例#4
0
    def test_empty_batch_file_should_produce_block(self,
                                                   mock_scheduler_complete):
        """
        In this case, the genesis batch, even with an empty list of batches,
        should produce a genesis block.
        Also:
         - the genesis.batch file should be deleted
         - the block_chain_id file should be created and populated

        Args:
            mock_scheduler_complete: presumably injected by a ``patch``
                decorator outside this view -- unused in the body.
        """
        genesis_file = self._with_empty_batch_file()
        block_store = self.make_block_store()
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Real (small, 10 MiB) LMDB-backed state so the genesis controller
        # can read a genuine merkle root.
        state_database = NativeLmdbDatabase(
            os.path.join(self._temp_dir, 'test_genesis.lmdb'),
            indexes=MerkleDatabase.create_index_configuration(),
            _size=10 * 1024 * 1024)
        merkle_db = MerkleDatabase(state_database)

        ctx_mgr = Mock(name='ContextManager')
        ctx_mgr.get_squash_handler.return_value = Mock()
        ctx_mgr.get_first_root.return_value = merkle_db.get_merkle_root()

        txn_executor = Mock(name='txn_executor')
        # A named (not str-specced) mock auto-creates ``add_block`` so its
        # call count can be asserted below.
        completer = Mock(name='completer')

        genesis_ctrl = GenesisController(
            context_manager=ctx_mgr,
            transaction_executor=txn_executor,
            completer=completer,
            block_store=block_store,
            state_view_factory=StateViewFactory(state_database),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        on_done_fn = Mock(return_value='')
        genesis_ctrl.start(on_done_fn)

        self.assertFalse(os.path.exists(genesis_file))

        self.assertIsNotNone(block_store.chain_head)
        self.assertEqual(1, on_done_fn.call_count)
        self.assertEqual(1, completer.add_block.call_count)
        self.assertEqual(block_store.chain_head.identifier,
                         self._read_block_chain_id())
示例#5
0
    def test_does_not_require_genesis_block_exists(self):
        """
        In this case, when there is
         - no genesis.batch file
         - a block already in the block store
        the GenesisController should not require genesis.
        """
        block = self._create_block()
        block_store = self.make_block_store({block.header_signature: block})
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Labels are passed as ``name=``; a positional string would be
        # interpreted by Mock as ``spec``.
        genesis_ctrl = GenesisController(
            context_manager=Mock(name='context_manager'),
            transaction_executor=Mock(name='txn_executor'),
            completer=Mock(name='completer'),
            block_store=block_store,
            state_view_factory=Mock(name='StateViewFactory'),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        self.assertEqual(False, genesis_ctrl.requires_genesis())
示例#6
0
    def test_requires_genesis(self):
        """
        In this case, when there is
         - a genesis.batch file
         - an empty block store (no chain head)
         - no network name configured by this test
        the GenesisController should require genesis.
        """
        self._with_empty_batch_file()

        block_store = self.make_block_store()
        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        # Labels are passed as ``name=``; a positional string would be
        # interpreted by Mock as ``spec``.
        genesis_ctrl = GenesisController(
            context_manager=Mock(name='context_manager'),
            transaction_executor=Mock(name='txn_executor'),
            completer=Mock(name='completer'),
            block_store=block_store,  # Empty block store
            state_view_factory=Mock(name='StateViewFactory'),
            identity_signer=self._signer,
            block_manager=block_manager,
            data_dir=self._temp_dir,
            config_dir=self._temp_dir,
            chain_id_manager=ChainIdManager(self._temp_dir),
            batch_sender=Mock(name='batch_sender'))

        self.assertEqual(True, genesis_ctrl.requires_genesis())
示例#7
0
    def __init__(self,
                 bind_network,
                 bind_component,
                 bind_consensus,
                 endpoint,
                 peering,
                 seeds_list,
                 peer_list,
                 data_dir,
                 config_dir,
                 identity_signer,
                 scheduler_type,
                 permissions,
                 minimum_peer_connectivity,
                 maximum_peer_connectivity,
                 state_pruning_block_depth,
                 fork_cache_keep_time,
                 network_public_key=None,
                 network_private_key=None,
                 roles=None):
        """Constructs a validator instance.

        Args:
            bind_network (str): the network endpoint
            bind_component (str): the component endpoint
            bind_consensus (str): the consensus engine endpoint
            endpoint (str): the zmq-style URI of this validator's
                publicly reachable endpoint
            peering (str): The type of peering approach. Either 'static'
                or 'dynamic'. In 'static' mode, no attempted topology
                buildout occurs -- the validator only attempts to initiate
                peering connections with endpoints specified in the
                peer_list. In 'dynamic' mode, the validator will first
                attempt to initiate peering connections with endpoints
                specified in the peer_list and then attempt to do a
                topology buildout starting with peer lists obtained from
                endpoints in the seeds_list. In either mode, the validator
                will accept incoming peer requests up to
                maximum_peer_connectivity.
            seeds_list (list of str): a list of addresses to connect
                to in order to perform the initial topology buildout
            peer_list (list of str): a list of peer addresses
            data_dir (str): path to the data directory; the global state,
                transaction receipt and block LMDB files are created here
            config_dir (str): path to the config directory
            identity_signer: cryptographic signer the validator uses for
                signing; its public key is handed to the component handlers
            scheduler_type: passed to the TransactionExecutor to select the
                transaction scheduler
            permissions: permission configuration passed to the
                PermissionVerifier
            minimum_peer_connectivity: passed through to Gossip
            maximum_peer_connectivity: passed through to Gossip
            state_pruning_block_depth: passed through to the ChainController
            fork_cache_keep_time: passed through to the ChainController
            network_public_key (optional): server public key for the network
                Interconnect; the network is secured only when both this and
                network_private_key are provided
            network_private_key (optional): server private key for the
                network Interconnect (see network_public_key)
            roles (optional): passed through to the network Interconnect
        """
        # The last two characters of bind_network are used as a suffix for
        # every on-disk database filename below.
        db_suffix = bind_network[-2:]

        # -- Setup Global State Database and Factory -- #
        global_state_db_filename = os.path.join(
            data_dir, 'merkle-{}.lmdb'.format(db_suffix))
        LOGGER.debug('global state database file is %s',
                     global_state_db_filename)
        global_state_db = NativeLmdbDatabase(
            global_state_db_filename,
            indexes=MerkleDatabase.create_index_configuration())
        state_view_factory = StateViewFactory(global_state_db)
        native_state_view_factory = NativeStateViewFactory(global_state_db)

        # -- Setup Receipt Store -- #
        receipt_db_filename = os.path.join(
            data_dir, 'txn_receipts-{}.lmdb'.format(db_suffix))
        LOGGER.debug('txn receipt store file is %s', receipt_db_filename)
        receipt_db = LMDBNoLockDatabase(receipt_db_filename, 'c')
        receipt_store = TransactionReceiptStore(receipt_db)

        # -- Setup Block Store -- #
        block_db_filename = os.path.join(
            data_dir, 'block-{}.lmdb'.format(db_suffix))
        LOGGER.debug('block store file is %s', block_db_filename)
        block_db = IndexedDatabase(
            block_db_filename,
            BlockStore.serialize_block,
            BlockStore.deserialize_block,
            flag='c',
            indexes=BlockStore.create_index_configuration())
        block_store = BlockStore(block_db)
        # Cache keep time passed to the Completer below.
        # NOTE(review): this was previously documented as needing to exceed
        # the journal's block cache keep time; no such cache is configured
        # in this constructor -- confirm whether that constraint still holds.
        base_keep_time = 1200

        block_manager = BlockManager()
        block_manager.add_store("commit_store", block_store)

        block_status_store = BlockValidationResultStore()

        # -- Setup Thread Pools -- #
        component_thread_pool = InstrumentedThreadPoolExecutor(
            max_workers=10, name='Component')
        network_thread_pool = InstrumentedThreadPoolExecutor(max_workers=10,
                                                             name='Network')
        client_thread_pool = InstrumentedThreadPoolExecutor(max_workers=5,
                                                            name='Client')
        sig_pool = InstrumentedThreadPoolExecutor(max_workers=3,
                                                  name='Signature')

        # -- Setup Dispatchers -- #
        component_dispatcher = Dispatcher()
        network_dispatcher = Dispatcher()

        # -- Setup Services -- #
        component_service = Interconnect(bind_component,
                                         component_dispatcher,
                                         secured=False,
                                         heartbeat=False,
                                         max_incoming_connections=20,
                                         monitor=True,
                                         max_future_callback_workers=10)

        # ZMQ identity for the network socket: the first 23 hex characters
        # of the SHA-512 of the current time's float hex representation.
        zmq_identity = hashlib.sha512(
            time.time().hex().encode()).hexdigest()[:23]

        # Network traffic is secured only when both halves of the server key
        # pair were supplied.
        secure = (network_public_key is not None
                  and network_private_key is not None)

        network_service = Interconnect(bind_network,
                                       dispatcher=network_dispatcher,
                                       zmq_identity=zmq_identity,
                                       secured=secure,
                                       server_public_key=network_public_key,
                                       server_private_key=network_private_key,
                                       heartbeat=True,
                                       public_endpoint=endpoint,
                                       connection_timeout=120,
                                       max_incoming_connections=100,
                                       max_future_callback_workers=10,
                                       authorize=True,
                                       signer=identity_signer,
                                       roles=roles)

        # -- Setup Transaction Execution Platform -- #
        context_manager = ContextManager(global_state_db)

        batch_tracker = BatchTracker(block_store.has_batch)

        settings_cache = SettingsCache(
            SettingsViewFactory(state_view_factory))

        transaction_executor = TransactionExecutor(
            service=component_service,
            context_manager=context_manager,
            settings_view_factory=SettingsViewFactory(state_view_factory),
            scheduler_type=scheduler_type,
            invalid_observers=[batch_tracker])

        component_service.set_check_connections(
            transaction_executor.check_connections)

        event_broadcaster = EventBroadcaster(component_service, block_store,
                                             receipt_store)

        # -- Consensus Engine -- #
        consensus_thread_pool = InstrumentedThreadPoolExecutor(
            max_workers=3, name='Consensus')
        consensus_dispatcher = Dispatcher()
        consensus_service = Interconnect(bind_consensus,
                                         consensus_dispatcher,
                                         secured=False,
                                         heartbeat=False,
                                         max_incoming_connections=20,
                                         max_future_callback_workers=10)

        consensus_notifier = ConsensusNotifier(consensus_service)

        # -- Setup P2P Networking -- #
        gossip = Gossip(network_service,
                        settings_cache,
                        lambda: block_store.chain_head,
                        block_store.chain_head_state_root,
                        consensus_notifier,
                        endpoint=endpoint,
                        peering_mode=peering,
                        initial_seed_endpoints=seeds_list,
                        initial_peer_endpoints=peer_list,
                        minimum_peer_connectivity=minimum_peer_connectivity,
                        maximum_peer_connectivity=maximum_peer_connectivity,
                        topology_check_frequency=1)

        completer = Completer(
            block_manager=block_manager,
            transaction_committed=block_store.has_transaction,
            get_committed_batch_by_id=block_store.get_batch,
            get_committed_batch_by_txn_id=(
                block_store.get_batch_by_transaction),
            get_chain_head=lambda: unwrap_if_not_none(block_store.chain_head),
            gossip=gossip,
            cache_keep_time=base_keep_time,
            cache_purge_frequency=30,
            requested_keep_time=300)
        self._completer = completer

        block_sender = BroadcastBlockSender(completer, gossip)
        batch_sender = BroadcastBatchSender(completer, gossip)
        chain_id_manager = ChainIdManager(data_dir)

        identity_view_factory = IdentityViewFactory(
            StateViewFactory(global_state_db))

        id_cache = IdentityCache(identity_view_factory)

        # -- Setup Permissioning -- #
        permission_verifier = PermissionVerifier(
            permissions, block_store.chain_head_state_root, id_cache)

        identity_observer = IdentityObserver(to_update=id_cache.invalidate,
                                             forked=id_cache.forked)

        settings_observer = SettingsObserver(
            to_update=settings_cache.invalidate, forked=settings_cache.forked)

        # -- Setup Journal -- #
        batch_injector_factory = DefaultBatchInjectorFactory(
            state_view_factory=state_view_factory, signer=identity_signer)

        block_publisher = BlockPublisher(
            block_manager=block_manager,
            transaction_executor=transaction_executor,
            transaction_committed=block_store.has_transaction,
            batch_committed=block_store.has_batch,
            state_view_factory=native_state_view_factory,
            block_sender=block_sender,
            batch_sender=batch_sender,
            chain_head=block_store.chain_head,
            identity_signer=identity_signer,
            data_dir=data_dir,
            config_dir=config_dir,
            permission_verifier=permission_verifier,
            batch_observers=[batch_tracker],
            batch_injector_factory=batch_injector_factory)

        block_validator = BlockValidator(
            block_manager=block_manager,
            block_store=block_store,
            view_factory=native_state_view_factory,
            transaction_executor=transaction_executor,
            block_status_store=block_status_store,
            permission_verifier=permission_verifier)

        chain_controller = ChainController(
            block_store=block_store,
            block_manager=block_manager,
            block_validator=block_validator,
            state_database=global_state_db,
            chain_head_lock=block_publisher.chain_head_lock,
            block_status_store=block_status_store,
            consensus_notifier=consensus_notifier,
            state_pruning_block_depth=state_pruning_block_depth,
            fork_cache_keep_time=fork_cache_keep_time,
            data_dir=data_dir,
            observers=[
                event_broadcaster, receipt_store, batch_tracker,
                identity_observer, settings_observer
            ])

        genesis_controller = GenesisController(
            context_manager=context_manager,
            transaction_executor=transaction_executor,
            completer=completer,
            block_manager=block_manager,
            block_store=block_store,
            state_view_factory=state_view_factory,
            identity_signer=identity_signer,
            data_dir=data_dir,
            config_dir=config_dir,
            chain_id_manager=chain_id_manager,
            batch_sender=batch_sender)

        responder = Responder(completer)

        # Completed blocks are handed straight to the chain controller.
        completer.set_on_block_received(chain_controller.queue_block)

        self._incoming_batch_sender = None

        # -- Register Message Handler -- #
        network_handlers.add(network_dispatcher, network_service, gossip,
                             completer, responder, network_thread_pool,
                             sig_pool,
                             lambda block_id: block_id in block_manager,
                             self.has_batch, permission_verifier,
                             block_publisher, consensus_notifier)

        component_handlers.add(component_dispatcher, gossip, context_manager,
                               transaction_executor, completer, block_store,
                               batch_tracker, global_state_db,
                               self.get_chain_head_state_root_hash,
                               receipt_store, event_broadcaster,
                               permission_verifier, component_thread_pool,
                               client_thread_pool, sig_pool, block_publisher,
                               identity_signer.get_public_key().as_hex())

        # -- Store Object References -- #
        self._component_dispatcher = component_dispatcher
        self._component_service = component_service
        self._component_thread_pool = component_thread_pool

        self._network_dispatcher = network_dispatcher
        self._network_service = network_service
        self._network_thread_pool = network_thread_pool

        consensus_proxy = ConsensusProxy(
            block_manager=block_manager,
            chain_controller=chain_controller,
            block_publisher=block_publisher,
            gossip=gossip,
            identity_signer=identity_signer,
            settings_view_factory=SettingsViewFactory(state_view_factory),
            state_view_factory=state_view_factory)

        consensus_handlers.add(consensus_dispatcher, consensus_thread_pool,
                               consensus_proxy, consensus_notifier)

        self._block_status_store = block_status_store

        self._consensus_dispatcher = consensus_dispatcher
        self._consensus_service = consensus_service
        self._consensus_thread_pool = consensus_thread_pool

        self._client_thread_pool = client_thread_pool
        self._sig_pool = sig_pool

        self._context_manager = context_manager
        self._transaction_executor = transaction_executor
        self._genesis_controller = genesis_controller
        self._gossip = gossip

        self._block_publisher = block_publisher
        self._block_validator = block_validator
        self._chain_controller = chain_controller
示例#8
0
class TestCompleter(unittest.TestCase):
    """Unit tests for the Completer's block and batch completion logic.

    Blocks and batches are built from freshly signed intkey-style
    transactions. The Completer's callbacks record received block ids in
    ``self.blocks`` and received batch header signatures in
    ``self.batches``, which the tests then inspect.
    """

    def setUp(self):
        self.block_store = BlockStore(
            DictDatabase(indexes=BlockStore.create_index_configuration()))
        self.block_manager = BlockManager()
        self.block_manager.add_store("commit_store", self.block_store)
        self.gossip = MockGossip()
        self.completer = Completer(
            block_manager=self.block_manager,
            transaction_committed=self.block_store.has_transaction,
            get_committed_batch_by_id=self.block_store.get_batch,
            get_committed_batch_by_txn_id=(
                self.block_store.get_batch_by_transaction),
            get_chain_head=lambda: self.block_store.chain_head,
            gossip=self.gossip)
        self.completer.set_on_block_received(self._on_block_received)
        self.completer.set_on_batch_received(self._on_batch_received)
        # NOTE(review): only read by _has_block, which is not registered
        # with the Completer in this class -- confirm it is still needed.
        self._has_block_value = True

        context = create_context('secp256k1')
        private_key = context.new_random_private_key()
        crypto_factory = CryptoFactory(context)
        self.signer = crypto_factory.new_signer(private_key)

        self.blocks = []
        self.batches = []

    def _on_block_received(self, block_id):
        """Record the id of a block the Completer passed on."""
        self.blocks.append(block_id)

    def _on_batch_received(self, batch):
        """Record the header signature of a batch the Completer passed on."""
        self.batches.append(batch.header_signature)

    def _has_block(self, batch):
        """Return the test-controlled ``_has_block_value`` flag."""
        return self._has_block_value

    def _create_transactions(self, count, missing_dep=False):
        """Create ``count`` signed intkey 'set' transactions.

        Args:
            count (int): number of transactions to create.
            missing_dep (bool): if True, every transaction declares a
                dependency on the non-existent transaction id "Missing".

        Returns:
            list of Transaction
        """
        # Loop invariants: the intkey namespace prefix and the signer's
        # public key are identical for every transaction.
        intkey_prefix = \
            hashlib.sha512('intkey'.encode('utf-8')).hexdigest()[0:6]
        public_key = self.signer.get_public_key().as_hex()

        txn_list = []
        for _ in range(count):
            payload = {
                'Verb': 'set',
                'Name': 'name' + str(random.randint(0, 100)),
                'Value': random.randint(0, 100)
            }
            addr = intkey_prefix + \
                hashlib.sha512(payload["Name"].encode('utf-8')).hexdigest()

            # Encode once so the hashed and the stored payload are the same
            # bytes.
            payload_bytes = cbor.dumps(payload)

            header = TransactionHeader(
                signer_public_key=public_key,
                family_name='intkey',
                family_version='1.0',
                inputs=[addr],
                outputs=[addr],
                dependencies=[],
                batcher_public_key=public_key,
                payload_sha512=hashlib.sha512(payload_bytes).hexdigest())

            if missing_dep:
                header.dependencies.append("Missing")

            header_bytes = header.SerializeToString()

            txn_list.append(Transaction(
                header=header_bytes,
                payload=payload_bytes,
                header_signature=self.signer.sign(header_bytes)))

        return txn_list

    def _create_batches(self, batch_count, txn_count, missing_dep=False):
        """Create ``batch_count`` signed batches of ``txn_count`` txns each.

        Args:
            batch_count (int): number of batches to create.
            txn_count (int): number of transactions per batch.
            missing_dep (bool): forwarded to ``_create_transactions``.

        Returns:
            list of Batch
        """
        public_key = self.signer.get_public_key().as_hex()

        batch_list = []
        for _ in range(batch_count):
            txn_list = self._create_transactions(txn_count,
                                                 missing_dep=missing_dep)

            batch_header = BatchHeader(signer_public_key=public_key)
            batch_header.transaction_ids.extend(
                txn.header_signature for txn in txn_list)

            header_bytes = batch_header.SerializeToString()

            batch_list.append(Batch(
                header=header_bytes,
                transactions=txn_list,
                header_signature=self.signer.sign(header_bytes)))

        return batch_list

    def _create_blocks(self,
                       block_count,
                       batch_count,
                       missing_predecessor=False,
                       missing_batch=False,
                       find_batch=True):
        """Create a chain of ``block_count`` signed blocks.

        Args:
            block_count (int): number of blocks; block ``i`` has block_num
                ``i`` and points at block ``i - 1`` (the first block points
                at NULL_BLOCK_IDENTIFIER).
            batch_count (int): batches per block (2 transactions each).
            missing_predecessor (bool): if True, every block's predecessor
                is the non-existent id "Missing".
            missing_batch (bool): if True, the last batch is listed in the
                header but removed from the block's batch list.
            find_batch (bool): with ``missing_batch``, add the removed batch
                to the Completer first so it can complete the block.

        Returns:
            list of Block
        """
        public_key = self.signer.get_public_key().as_hex()

        block_list = []
        for i in range(block_count):
            batch_list = self._create_batches(batch_count, 2)
            batch_ids = [batch.header_signature for batch in batch_list]

            if missing_predecessor:
                predecessor = "Missing"
            else:
                predecessor = (block_list[i - 1].header_signature
                               if i > 0 else NULL_BLOCK_IDENTIFIER)

            block_header = BlockHeader(
                signer_public_key=public_key,
                batch_ids=batch_ids,
                block_num=i,
                previous_block_id=predecessor)

            header_bytes = block_header.SerializeToString()

            signature = self.signer.sign(header_bytes)

            if missing_batch:
                if find_batch:
                    self.completer.add_batch(batch_list[-1])
                batch_list = batch_list[:-1]

            block_list.append(Block(header=header_bytes,
                                    batches=batch_list,
                                    header_signature=signature))

        return block_list

    def test_good_block(self):
        """
        Add completed block to completer. Block should be passed to
        on_block_received.
        """
        block = self._create_blocks(1, 1)[0]
        self.completer.add_block(block)
        self.assertIn(block.header_signature, self.blocks)

    def test_duplicate_block(self):
        """
        Submit same block twice; it should be passed on only once.
        """
        block = self._create_blocks(1, 1)[0]
        self.completer.add_block(block)
        self.completer.add_block(block)
        self.assertIn(block.header_signature, self.blocks)
        self.assertEqual(len(self.blocks), 1)

    def test_block_missing_predecessor(self):
        """
        The block is completed but the predecessor is missing.
        """
        block = self._create_blocks(1, 1, missing_predecessor=True)[0]
        self._has_block_value = False
        self.completer.add_block(block)
        self.assertEqual(len(self.blocks), 0)
        self.assertIn("Missing", self.gossip.requested_blocks)
        header = BlockHeader(previous_block_id=NULL_BLOCK_IDENTIFIER)
        missing_block = Block(header_signature="Missing",
                              header=header.SerializeToString())
        self._has_block_value = True
        self.completer.add_block(missing_block)
        self.assertIn(block.header_signature, self.blocks)
        self.assertEqual(block,
                         self.completer.get_block(block.header_signature))

    def test_block_with_extra_batch(self):
        """
        The block has a batch that is not in the batch_id list.
        """
        block = self._create_blocks(1, 1)[0]
        batches = self._create_batches(1, 1, True)
        block.batches.extend(batches)
        self.completer.add_block(block)
        self.assertEqual(len(self.blocks), 0)

    def test_block_missing_batch(self):
        """
        The block is missing a batch and the batch is in the cache. The block
        will be built and passed to on_block_received. This puts the block
        in the self.blocks list.
        """
        block = self._create_blocks(1, 2, missing_batch=True)[0]
        self.completer.add_block(block)
        self.assertIn(block.header_signature, self.blocks)
        self.assertEqual(block,
                         self.completer.get_block(block.header_signature))

    def test_block_missing_batch_not_in_cache(self):
        """
        The block is missing a batch and the batch is not in the cache.
        The batch will be requested and the block will not be passed to
        on_block_received.
        """
        block = self._create_blocks(1, 3, missing_batch=True,
                                    find_batch=False)[0]
        self.completer.add_block(block)
        header = BlockHeader()
        header.ParseFromString(block.header)
        self.assertIn(header.batch_ids[-1], self.gossip.requested_batches)

    def test_block_batches_wrong_order(self):
        """
        The block has all of its batches but they are in the wrong order. The
        batches will be reordered and the block will be passed to
        on_block_received.
        """
        block = self._create_blocks(1, 6)[0]
        batches = list(block.batches)
        random.shuffle(batches)
        del block.batches[:]
        block.batches.extend(batches)
        self.completer.add_block(block)
        self.assertIn(block.header_signature, self.blocks)

    def test_block_batches_wrong_batch(self):
        """
        The block has the correct number of batches but one is not in the
        batch_id list. This block should be dropped.
        """
        block = self._create_blocks(1, 6)[0]
        batches = list(block.batches)
        batches[-1] = Batch(header_signature="Extra")
        # Replace (not append to) the block's batches: the count must still
        # match batch_ids so only the identity of the last batch is wrong.
        # Appending would give 12 batches and merely repeat
        # test_block_with_extra_batch.
        del block.batches[:]
        block.batches.extend(batches)
        self.completer.add_block(block)
        self.assertEqual(len(self.blocks), 0)

    def test_good_batch(self):
        """
        Add complete batch to completer. The batch should be passed to
        on_batch_received.
        """
        batch = self._create_batches(1, 1)[0]
        self.completer.add_batch(batch)
        self.assertIn(batch.header_signature, self.batches)
        self.assertEqual(batch,
                         self.completer.get_batch(batch.header_signature))

    def test_batch_with_missing_dep(self):
        """
        Add batch to completer that has a missing dependency. The missing
        transaction's batch should be requested and the missing batch is then
        added to the completer. The incomplete batch should be rechecked
        and passed to on_batch_received.
        """
        batch = self._create_batches(1, 1, missing_dep=True)[0]
        self.completer.add_batch(batch)
        self.assertIn("Missing", self.gossip.requested_batches_by_txn_id)

        missing = Transaction(header_signature="Missing")
        missing_batch = Batch(header_signature="Missing_batch",
                              transactions=[missing])
        self.completer.add_batch(missing_batch)
        self.assertIn(missing_batch.header_signature, self.batches)
        self.assertIn(batch.header_signature, self.batches)
        self.assertEqual(missing_batch,
                         self.completer.get_batch_by_transaction("Missing"))
class BlockTreeManager:
    def __str__(self):
        """Render the manager as the string form of its block cache."""
        return "%s" % (self.block_cache,)

    def __repr__(self):
        """Represent the manager by the repr of its block cache."""
        return "%r" % (self.block_cache,)

    def __init__(self, with_genesis=True):
        """Create in-memory stores, signers and a BlockPublisher for tests.

        Args:
            with_genesis: When True, generate a genesis block, put it into
                the block manager, persist it to the "commit_store" and use
                it as the publisher's initial chain head. When False, the
                publisher is created without a chain head.
        """
        self.block_sender = MockBlockSender()
        self.batch_sender = MockBatchSender()
        self.block_store = BlockStore(
            DictDatabase(indexes=BlockStore.create_index_configuration()))
        self.block_cache = BlockCache(self.block_store)
        # NOTE(review): this temp dir (holding merkle.lmdb) is not removed
        # here; whoever owns the manager is responsible for cleaning it up.
        self.dir = tempfile.mkdtemp()
        self.state_db = NativeLmdbDatabase(
            os.path.join(self.dir, "merkle.lmdb"),
            MerkleDatabase.create_index_configuration())

        self.state_view_factory = NativeStateViewFactory(self.state_db)

        self.block_manager = BlockManager()
        self.block_manager.add_store("commit_store", self.block_store)

        # Two independent keys: self.signer signs batches and transactions,
        # self.identity_signer signs blocks.
        context = create_context('secp256k1')
        private_key = context.new_random_private_key()
        crypto_factory = CryptoFactory(context)
        self.signer = crypto_factory.new_signer(private_key)

        identity_private_key = context.new_random_private_key()
        self.identity_signer = crypto_factory.new_signer(identity_private_key)

        chain_head = None
        if with_genesis:
            self.genesis_block = self.generate_genesis_block()
            chain_head = self.genesis_block
            self.block_manager.put([chain_head.block])
            self.block_manager.persist(chain_head.block.header_signature,
                                       "commit_store")

        self.block_publisher = BlockPublisher(
            block_manager=self.block_manager,
            transaction_executor=MockTransactionExecutor(),
            transaction_committed=self.block_store.has_transaction,
            batch_committed=self.block_store.has_batch,
            state_view_factory=self.state_view_factory,
            block_sender=self.block_sender,
            # Batches go through the dedicated batch sender, not the
            # block sender.
            batch_sender=self.batch_sender,
            # Without a genesis block there is no head to hand over, and
            # dereferencing chain_head.block would raise AttributeError.
            # NOTE(review): assumes BlockPublisher accepts chain_head=None.
            chain_head=None if chain_head is None else chain_head.block,
            identity_signer=self.identity_signer,
            data_dir=None,
            config_dir=None,
            permission_verifier=MockPermissionVerifier(),
            batch_observers=[])

    @property
    def chain_head(self):
        """The chain head as currently reported by the block store."""
        store = self.block_store
        return store.chain_head

    def generate_block(self,
                       previous_block=None,
                       add_to_store=False,
                       add_to_cache=False,
                       batch_count=1,
                       batches=None,
                       status=BlockStatus.Unknown,
                       invalid_consensus=False,
                       invalid_batch=False,
                       invalid_signature=False,
                       weight=0):
        """Build, sign and register a block on top of ``previous_block``.

        Args:
            previous_block: Parent as a Block, BlockWrapper or block id
                (resolved through _get_block). When None, the current
                chain head is used.
            add_to_store: Also insert the block into self.block_store.
            add_to_cache: Also insert the block into self.block_cache.
            batch_count: Number of freshly generated batches to include.
            batches: Optional pre-built batches to include.
            status: Value assigned to the returned wrapper's ``status``.
            invalid_consensus: Overwrite the header's consensus field with
                b'BAD' before the header is serialized and signed.
            invalid_batch: Add one extra batch built from the payload 'BAD'.
            invalid_signature: Set the signature to 'BAD' instead of
                signing, and force the block's header_signature to "BAD".
            weight: Forwarded to the mock consensus ``finalize_block``.

        Returns:
            The new BlockWrapper. Its block is always put into
            self.block_manager, regardless of the add_to_* flags.
        """

        previous = self._get_block(previous_block)
        if previous is None:
            previous = self.chain_head

        header = BlockHeader(
            previous_block_id=previous.identifier,
            signer_public_key=self.identity_signer.get_public_key().as_hex(),
            block_num=previous.block_num + 1)

        block_builder = BlockBuilder(header)
        if batches:
            block_builder.add_batches(batches)

        if batch_count != 0:
            block_builder.add_batches(
                [self._generate_batch() for _ in range(batch_count)])

        if invalid_batch:
            block_builder.add_batches(
                [self._generate_batch_from_payload('BAD')])

        # Fixed placeholder state root (70 '0' characters).
        block_builder.set_state_hash('0' * 70)

        consensus = mock_consensus.BlockPublisher()
        consensus.finalize_block(block_builder.block_header, weight=weight)

        # Corrupt consensus before serialization so the signed header
        # itself carries the bad value.
        if invalid_consensus:
            block_builder.block_header.consensus = b'BAD'

        header_bytes = block_builder.block_header.SerializeToString()
        if invalid_signature:
            block_builder.set_signature('BAD')
        else:
            signature = self.identity_signer.sign(header_bytes)
            block_builder.set_signature(signature)

        block_wrapper = BlockWrapper(block_builder.build_block())

        # NOTE(review): these batches are appended after the header was
        # serialized and signed. If build_block() already includes the
        # batches given to block_builder.add_batches() above, ``batches``
        # ends up on the block twice, and generate_batch() below creates
        # batch_count *new* batches distinct from the builder's. Confirm
        # this is intended before relying on block contents.
        if batches:
            block_wrapper.block.batches.extend(batches)

        if batch_count:
            block_wrapper.block.batches.extend(
                [self.generate_batch() for _ in range(batch_count)])

        # The block id becomes the literal "BAD"; repeated invalid-signature
        # blocks will therefore share an id.
        if invalid_signature:
            block_wrapper.block.header_signature = "BAD"

        # NOTE(review): if BlockWrapper.header returns a freshly parsed copy,
        # this assignment does not change the stored block; the signed
        # header was already given b'BAD' above.
        if invalid_consensus:
            block_wrapper.header.consensus = b'BAD'

        block_wrapper.status = status

        self.block_manager.put([block_wrapper.block])
        if add_to_cache:
            self.block_cache[block_wrapper.identifier] = block_wrapper

        if add_to_store:
            self.block_store[block_wrapper.identifier] = block_wrapper

        LOGGER.debug("Generated %s", dumps_block(block_wrapper))
        return block_wrapper

    def generate_chain(self,
                       root_block,
                       blocks,
                       params=None,
                       exclude_head=True):
        """
        Generate a new chain of blocks on top of the root block.

        Every generated block is put into the block manager (see
        generate_block); whether it is also added to the block cache or
        block store is controlled by each definition's add_to_cache /
        add_to_store flags.

        Args:
            root_block: Block (or block id) to build on. When None, a new
                genesis block is generated, added to the block store and
                used as the root.
            blocks: Either the number of blocks to generate, or an iterable
                of block-definition dicts (see _block_def), one per block.
            params: _block_def keyword arguments applied to every generated
                block when ``blocks`` is a count.
            exclude_head: When ``blocks`` is a count, build the last (head)
                block with default parameters instead of ``params``.

        Returns:
            The list of generated BlockWrappers, ordered from the block
            built directly on the root up to the new head.
        """
        if params is None:
            params = {}

        if root_block is None:
            previous = self.generate_genesis_block()
            self.block_store[previous.identifier] = previous
        else:
            previous = self._get_block(root_block)

        # Only the range() call decides between "count" and "sequence of
        # definitions". Keeping _block_def(**params) out of the try lets a
        # bad params key surface as a clear TypeError instead of being
        # misread as "blocks is a sequence".
        try:
            block_range = range(blocks)
        except TypeError:
            block_defs = blocks
        else:
            block_defs = [self._block_def(**params) for _ in block_range]
            # An empty chain has no head to reset.
            if exclude_head and block_defs:
                block_defs[-1] = self._block_def()

        out = []
        for block_def in block_defs:
            new_block = self.generate_block(previous_block=previous,
                                            **block_def)
            out.append(new_block)
            previous = new_block
        return out

    def create_block(self,
                     payload='payload',
                     batch_count=1,
                     previous_block_id=NULL_BLOCK_IDENTIFIER,
                     block_num=0):
        """Build a standalone block signed by the identity signer.

        Unlike generate_block, nothing is registered with the block manager,
        cache or store; the wrapper is only built and returned.

        Args:
            payload: Payload used for the single transaction in each batch.
            batch_count: Number of batches to include.
            previous_block_id: Id recorded as the block's predecessor.
            block_num: Block number recorded in the header.

        Returns:
            A BlockWrapper around the newly built block.
        """
        signer_key = self.identity_signer.get_public_key().as_hex()
        builder = BlockBuilder(BlockHeader(
            previous_block_id=previous_block_id,
            signer_public_key=signer_key,
            block_num=block_num))

        new_batches = [self._generate_batch_from_payload(payload)
                       for _ in range(batch_count)]
        builder.add_batches(new_batches)
        builder.set_state_hash('0' * 70)

        serialized_header = builder.block_header.SerializeToString()
        builder.set_signature(self.identity_signer.sign(serialized_header))

        wrapper = BlockWrapper(builder.build_block())
        LOGGER.debug("Generated %s", dumps_block(wrapper))
        return wrapper

    def generate_genesis_block(self):
        """Build block number 0 with no predecessor and a 'Genesis' payload."""
        genesis = self.create_block(
            block_num=0,
            previous_block_id=NULL_BLOCK_IDENTIFIER,
            payload='Genesis')
        return genesis

    def _block_def(self,
                   add_to_store=False,
                   add_to_cache=False,
                   batch_count=1,
                   status=BlockStatus.Unknown,
                   invalid_consensus=False,
                   invalid_batch=False,
                   invalid_signature=False,
                   weight=0):
        """Bundle generate_block keyword arguments into a dict.

        The result is meant to be unpacked into generate_block(), as
        generate_chain does for each block it builds.
        """
        return dict(
            add_to_cache=add_to_cache,
            add_to_store=add_to_store,
            batch_count=batch_count,
            status=status,
            invalid_consensus=invalid_consensus,
            invalid_batch=invalid_batch,
            invalid_signature=invalid_signature,
            weight=weight,
        )

    def _get_block_id(self, block):
        """Return the id for ``block``.

        None maps to None, Block and BlockWrapper objects map to their
        header_signature, and any other value maps to ``str(block)``.
        """
        if block is None:
            return None
        if isinstance(block, (Block, BlockWrapper)):
            return block.header_signature
        return str(block)

    def _get_block(self, block):
        """Resolve ``block`` to a BlockWrapper.

        None passes through unchanged, a Block is wrapped, and a
        BlockWrapper is returned as-is. A string is used directly as a
        block_cache key; any other value is converted with str() first.
        """
        if block is None:
            return None
        if isinstance(block, Block):
            return BlockWrapper(block)
        if isinstance(block, BlockWrapper):
            return block
        cache_key = block if isinstance(block, str) else str(block)
        return self.block_cache[cache_key]

    def generate_batch(self, txn_count=2, missing_deps=False, txns=None):
        """Build a batch via _generate_batch and log it at DEBUG level.

        All arguments are forwarded unchanged to _generate_batch.
        """
        new_batch = self._generate_batch(txn_count=txn_count,
                                         missing_deps=missing_deps,
                                         txns=txns)
        LOGGER.debug("Generated Batch:\n%s", dumps_batch(new_batch))
        return new_batch

    def _generate_batch(self, txn_count=2, missing_deps=False, txns=None):
        """Build a batch signed by self.signer.

        Args:
            txn_count: Number of new transactions (payloads 'txn_0',
                'txn_1', ...) appended after any ``txns`` given.
            missing_deps: When True, replace the last transaction with a new
                one that depends on it, so the batch references a dependency
                it does not contain. Requires at least one transaction.
            txns: Optional transactions placed first in the batch. The
                caller's sequence is copied and never modified.

        Returns:
            A Batch whose header lists its transactions' ids in order.
        """
        # Work on a copy: extending or replacing items in the caller's list
        # would leak this batch's transactions back into the caller.
        txns = [] if txns is None else list(txns)

        if txn_count != 0:
            txns.extend(
                self.generate_transaction('txn_' + str(i))
                for i in range(txn_count))

        if missing_deps:
            target_txn = txns[-1]
            txn_missing_deps = self.generate_transaction(
                payload='this one has a missing dependency',
                deps=[target_txn.header_signature])
            # replace the targeted txn with the missing deps txn
            txns[-1] = txn_missing_deps

        batch_header = BatchHeader(
            signer_public_key=self.signer.get_public_key().as_hex(),
            transaction_ids=[txn.header_signature
                             for txn in txns]).SerializeToString()

        batch = Batch(header=batch_header,
                      header_signature=self.signer.sign(batch_header),
                      transactions=txns)

        return batch

    def generate_transaction(self, payload='txn', deps=None):
        """Build a 'test' family (version '1') transaction signed by self.signer.

        Args:
            payload: Text payload. It is UTF-8 encoded for the transaction
                body, and its SHA-512 hex digest is recorded in the header.
            deps: Optional list of transaction ids this transaction depends
                on.

        Returns:
            A Transaction with header, header_signature and payload set.
        """
        payload_encoded = payload.encode('utf-8')
        # self.signer acts as both batcher and transaction signer here.
        public_key = self.signer.get_public_key().as_hex()

        txn_header = TransactionHeader(
            dependencies=([] if deps is None else deps),
            batcher_public_key=public_key,
            family_name='test',
            family_version='1',
            nonce=_generate_id(16),
            # The hex digest is already text; assumes payload_sha512 is a
            # proto `string` field (as in Sawtooth's TransactionHeader), so
            # no bytes round-trip is needed.
            payload_sha512=hashlib.sha512(payload_encoded).hexdigest(),
            signer_public_key=public_key).SerializeToString()

        txn = Transaction(header=txn_header,
                          header_signature=self.signer.sign(txn_header),
                          payload=payload_encoded)

        return txn

    def _generate_batch_from_payload(self, payload):
        """Build a batch signed by self.signer holding one transaction.

        Args:
            payload: Text payload for the batch's single transaction.

        Returns:
            A Batch containing exactly the transaction produced by
            generate_transaction(payload).
        """
        # Reuse the shared batch builder rather than duplicating its
        # header/signature assembly; txn_count=0 keeps it from adding any
        # generated transactions of its own.
        return self._generate_batch(
            txn_count=0, txns=[self.generate_transaction(payload)])