def compute_state_hashes_wo_scheduler(self):
    """Creates a state hash from the state updates from each txn in a
    valid batch.

    The address/value writes of every transaction in every valid batch
    are replayed serially into an in-memory MerkleDatabase.  Whenever a
    batch result carries a ``state_hash`` (yaml files may specify expected
    state roots), all writes gathered so far are committed and the
    resulting merkle root is recorded.

    Returns state_hashes (list of str): The merkle roots from state
        changes in 1 or more blocks in the yaml file.  When no batch
        result carries a state hash, the list holds the single root that
        ``tree.update`` computes from all gathered writes.
    """
    tree = MerkleDatabase(database=DictDatabase())
    state_hashes = []
    updates = {}
    for batch in self._batches:
        result = self._batch_results[batch.header_signature]
        if result.is_valid:
            for txn in batch.transactions:
                _, address_values = self._txn_execution[txn.header_signature]
                # since this is entirely serial, any overwrite
                # of an address is expected and desirable.
                for pair in address_values:
                    updates.update(pair)

        # This handles yaml files that have state roots in them
        if result.state_hash is not None:
            s_h = tree.update(set_items=updates, virtual=False)
            tree.set_merkle_root(merkle_root=s_h)
            state_hashes.append(s_h)

    if not state_hashes:
        state_hashes.append(tree.update(set_items=updates))
    return state_hashes
def compute_state_hashes_wo_scheduler(self):
    """Replay the valid batches into a fresh in-memory Merkle tree and
    collect the resulting state roots.

    Each transaction of each valid batch contributes its recorded
    address/value writes, applied in order so that later writes win.
    Every batch result that carries a ``state_hash`` triggers a commit of
    all writes gathered so far, and the committed root is recorded.

    Returns state_hashes (list of str): The merkle roots from state
        changes in 1 or more blocks in the yaml file.  If no batch result
        carries a state hash, the list holds the single root computed by
        ``update`` from all gathered writes.
    """
    merkle = MerkleDatabase(database=DictDatabase())
    roots = []
    pending = {}

    for batch in self._batches:
        outcome = self._batch_results[batch.header_signature]

        if outcome.is_valid:
            for txn in batch.transactions:
                _ignored, written = self._txn_execution[txn.header_signature]
                merged = {}
                for pair in written:
                    merged.update(pair)
                # Serial replay: overwriting an earlier value is intended.
                pending.update(merged)

        if outcome.state_hash is None:
            continue

        # yaml files may pin expected state roots to batches; commit here.
        root = merkle.update(set_items=pending, virtual=False)
        merkle.set_merkle_root(merkle_root=root)
        roots.append(root)

    if not roots:
        roots.append(merkle.update(set_items=pending))
    return roots
class TestSawtoothMerkleTrie:
    """Test helper wrapping a MerkleDatabase over a scratch LMDB file.

    Use it as a context manager: entering returns the helper, and exiting
    closes the trie and deletes the scratch directory.  Keys passed to the
    accessors are hashed with ``_hash`` unless ``ishash`` is True.
    """

    def __init__(self):
        import tempfile  # function-scope: the module import block is not in view

        # A private temporary directory per instance.  A fixed path such as
        # '/tmp/sawtooth' must already exist and is shared by concurrent runs.
        self.dir = tempfile.mkdtemp()
        self.file = os.path.join(self.dir, 'merkle.lmdb')

        self.lmdb = lmdb_nolock_database.LMDBNoLockDatabase(self.file, 'n')
        self.trie = MerkleDatabase(self.lmdb)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        import shutil  # function-scope: the module import block is not in view

        try:
            self.trie.close()
        finally:
            # Remove the scratch directory even if closing the trie fails.
            shutil.rmtree(self.dir, ignore_errors=True)

    # assertions

    def assert_value_at_address(self, address, value, ishash=False):
        """Assert the trie holds ``value`` at ``address`` under the current root."""
        assert self.get(address, ishash) == value

    def assert_no_key(self, key):
        """Assert that looking up ``key`` raises KeyError."""
        with pytest.raises(KeyError):
            self.get(key)

    def assert_root(self, expected):
        """Assert the trie's current merkle root equals ``expected``."""
        assert expected == self.get_merkle_root()

    def assert_not_root(self, *not_roots):
        """Assert the trie's current merkle root differs from every argument."""
        root = self.get_merkle_root()
        for not_root in not_roots:
            assert root != not_root

    # trie accessors

    # For convenience, assume keys are not hashed
    # unless otherwise indicated.

    def set(self, key, val, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.set(key_, val)

    def get(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.get(key_)

    def delete(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.delete(key_)

    def set_merkle_root(self, root):
        self.trie.set_merkle_root(root)

    def get_merkle_root(self):
        return self.trie.get_merkle_root()

    def update(self, set_items, delete_items=None, virtual=True):
        return self.trie.update(set_items, delete_items, virtual=virtual)
def create_view(self, state_root_hash=None):
    """Creates a StateView for the given state root hash.

    Args:
        state_root_hash (str): The state root hash of the state view
            to return.  If None, no root is set, and the view uses the
            root that a freshly constructed
            ``MerkleDatabase(self._database)`` starts with.

    Returns:
        StateView: state view locked to the given root hash.

    Raises:
        KeyError: presumably raised by ``set_merkle_root`` when
            ``state_root_hash`` is not present in the database (confirm).
    """
    # Create a default Merkle database and if we have a state root hash,
    # update the Merkle database's root to that
    merkle_db = MerkleDatabase(self._database)
    if state_root_hash is not None:
        merkle_db.set_merkle_root(state_root_hash)

    return StateView(merkle_db)
def create_view(self, state_root_hash=None):
    """Return a StateView over this factory's database.

    Args:
        state_root_hash (str, optional): Merkle root the view should be
            pinned to.  When None, no root is set, and the view keeps
            whatever root a freshly constructed MerkleDatabase starts with.

    Returns:
        StateView: a view reading state as of the selected root.
    """
    tree = MerkleDatabase(self._database)
    if state_root_hash is None:
        # Nothing to pin: keep the MerkleDatabase's starting root.
        return StateView(tree)

    tree.set_merkle_root(state_root_hash)
    return StateView(tree)
class ClientStateGetRequestHandler(object):
    """Answers ClientStateGetRequest messages by reading one address from
    the Merkle tree at the requested root.  A ClientStateGetResponse is
    sent back through the responder."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, message, responder):
        """Handle one ClientStateGetRequest.

        Args:
            message: incoming validator message.  Its ``content`` holds
                the serialized ClientStateGetRequest; its ``sender`` and
                ``correlation_id`` are echoed in the reply.
            responder: object whose ``send`` receives the reply
                ``validator_pb2.Message``.

        Response status:
        - NORESOURCE if ``set_merkle_root`` or ``get`` raises KeyError.
        - NONLEAF if ``get`` raises ValueError.
        - ERROR if the request fails to deserialize.
        - OK otherwise, with ``value`` set.
        """
        resp_proto = client_pb2.ClientStateGetResponse
        request = client_pb2.ClientStateGetRequest()
        # Track the outcome as an explicit status.  The previous
        # ``status or ERROR`` fallback depended on no status enum value
        # being falsy.
        status = resp_proto.OK
        value = None

        try:
            request.ParseFromString(message.content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as e:
            status = resp_proto.NORESOURCE
            LOGGER.info(e)
        except DecodeError:
            status = resp_proto.ERROR
            LOGGER.info(
                "Expected protobuf of class %s failed to "
                "deserialize", request)

        if status == resp_proto.OK:
            address = request.address
            try:
                value = self._tree.get(address)
            except KeyError:
                status = resp_proto.NORESOURCE
                LOGGER.debug("No entry at state address %s", address)
            except ValueError:
                status = resp_proto.NONLEAF
                LOGGER.debug("Node at state address %s is a nonleaf", address)

        response = resp_proto(status=status)
        if status == resp_proto.OK:
            response.value = value

        responder.send(
            validator_pb2.Message(
                sender=message.sender,
                message_type=validator_pb2.Message.CLIENT_STATE_GET_RESPONSE,
                correlation_id=message.correlation_id,
                content=response.SerializeToString()))
def compute_state_hashes_wo_scheduler(self, base_dir):
    """Creates a state hash from the state updates from each txn in a
    valid batch.

    The sets and deletes of every transaction in every valid batch are
    replayed serially into a MerkleDatabase.  The database is backed by a
    scratch LMDB file under ``base_dir``.  Whenever a batch result carries
    a ``state_hash``, the gathered changes are committed and the
    resulting root is recorded.

    Args:
        base_dir (str): Directory in which the scratch LMDB file
            ``compute_state_hashes_wo_scheduler.lmdb`` is created.

    Returns state_hashes (list of str): The merkle roots from state
        changes in 1 or more blocks in the yaml file.
    """
    database = NativeLmdbDatabase(
        os.path.join(base_dir, 'compute_state_hashes_wo_scheduler.lmdb'),
        indexes=MerkleDatabase.create_index_configuration(),
        _size=10 * 1024 * 1024)

    tree = MerkleDatabase(database=database)
    state_hashes = []
    # Address -> value for every set not undone by a later delete.  It is
    # cumulative, so already-committed sets are re-applied on each commit.
    updates = {}
    # Addresses present in the committed tree.
    committed = set()
    # Committed addresses deleted since the last commit.  These must be
    # removed from the tree itself; dropping them from ``updates`` alone
    # would leave them in the tree.
    pending_deletes = set()

    for batch in self._batches:
        result = self._batch_results[batch.header_signature]
        if result.is_valid:
            for txn in batch.transactions:
                _, address_values, deletes = \
                    self._txn_execution[txn.header_signature]
                # since this is entirely serial, any overwrite
                # of an address is expected and desirable.
                for pair in address_values:
                    updates.update(pair)
                    pending_deletes.difference_update(pair.keys())

                for address in deletes:
                    updates.pop(address, None)
                    if address in committed:
                        pending_deletes.add(address)

        # This handles yaml files that have state roots in them
        if result.state_hash is not None:
            # ``updates`` never shares an address with ``pending_deletes``.
            s_h = tree.update(updates, pending_deletes, virtual=False)
            tree.set_merkle_root(merkle_root=s_h)
            state_hashes.append(s_h)
            committed.update(updates)
            committed.difference_update(pending_deletes)
            pending_deletes = set()

    if not state_hashes:
        state_hashes.append(tree.update(set_items=updates))
    return state_hashes
def compute_state_hashes_wo_scheduler(self, base_dir):
    """Creates a state hash from the state updates from each txn in a
    valid batch.

    The sets and deletes of every transaction in every valid batch are
    replayed serially into a MerkleDatabase.  The database is backed by a
    scratch LMDB file under ``base_dir``.  Whenever a batch result carries
    a ``state_hash``, the gathered changes are committed and the
    resulting root is recorded.

    Args:
        base_dir (str): Directory in which the scratch LMDB file
            ``compute_state_hashes_wo_scheduler.lmdb`` is created.

    Returns state_hashes (list of str): The merkle roots from state
        changes in 1 or more blocks in the yaml file.
    """
    database = NativeLmdbDatabase(
        os.path.join(base_dir, 'compute_state_hashes_wo_scheduler.lmdb'),
        indexes=MerkleDatabase.create_index_configuration(),
        _size=10 * 1024 * 1024)

    tree = MerkleDatabase(database=database)
    state_hashes = []
    # Address -> value for every set not undone by a later delete.  It is
    # cumulative, so already-committed sets are re-applied on each commit.
    updates = {}
    # Addresses present in the committed tree.
    committed = set()
    # Committed addresses deleted since the last commit.  These must be
    # removed from the tree itself; dropping them from ``updates`` alone
    # would leave them in the tree.
    pending_deletes = set()

    for batch in self._batches:
        result = self._batch_results[batch.header_signature]
        if result.is_valid:
            for txn in batch.transactions:
                _, address_values, deletes = \
                    self._txn_execution[txn.header_signature]
                # since this is entirely serial, any overwrite
                # of an address is expected and desirable.
                for pair in address_values:
                    updates.update(pair)
                    pending_deletes.difference_update(pair.keys())

                for address in deletes:
                    updates.pop(address, None)
                    if address in committed:
                        pending_deletes.add(address)

        # This handles yaml files that have state roots in them
        if result.state_hash is not None:
            # ``updates`` never shares an address with ``pending_deletes``.
            s_h = tree.update(updates, pending_deletes, virtual=False)
            tree.set_merkle_root(merkle_root=s_h)
            state_hashes.append(s_h)
            committed.update(updates)
            committed.difference_update(pending_deletes)
            pending_deletes = set()

    if not state_hashes:
        state_hashes.append(tree.update(set_items=updates))
    return state_hashes
class StateGetRequestHandler(Handler):
    """Handler that answers ClientStateGetRequest messages with the value
    stored at one address of the Merkle tree."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, identity, message_content):
        """Build the ClientStateGetResponse for one serialized request.

        Response status:
        - NORESOURCE if setting the root or reading the address raises
          KeyError.
        - NONLEAF if reading the address raises ValueError.
        - ERROR if the request fails to deserialize.
        - OK otherwise, with ``value`` filled in.
        """
        request = client_pb2.ClientStateGetRequest()
        proto = client_pb2.ClientStateGetResponse
        outcome = proto.OK
        found = None

        try:
            request.ParseFromString(message_content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as err:
            outcome = proto.NORESOURCE
            LOGGER.debug(err)
        except DecodeError:
            outcome = proto.ERROR
            LOGGER.info(
                "Expected protobuf of class %s failed to "
                "deserialize", request)

        if outcome == proto.OK:
            target = request.address
            try:
                found = self._tree.get(target)
            except KeyError:
                outcome = proto.NORESOURCE
                LOGGER.debug("No entry at state address %s", target)
            except ValueError:
                outcome = proto.NONLEAF
                LOGGER.debug("Node at state address %s is a nonleaf", target)

        reply = proto(status=outcome)
        if outcome == proto.OK:
            reply.value = found

        return HandlerResult(
            status=HandlerStatus.RETURN,
            message_out=reply,
            message_type=validator_pb2.Message.CLIENT_STATE_GET_RESPONSE)
class ClientStateListRequestHandler(object):
    """Answers ClientStateListRequest messages with the Merkle-tree leaves
    under the requested root and address prefix.  A
    ClientStateListResponse is sent back through the responder."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, message, responder):
        """Handle one ClientStateListRequest.

        Args:
            message: incoming validator message.  Its ``content`` holds
                the serialized ClientStateListRequest; its ``sender`` and
                ``correlation_id`` are echoed in the reply.
            responder: object whose ``send`` receives the reply
                ``validator_pb2.Message``.

        Response status:
        - NORESOURCE if ``set_merkle_root`` raises KeyError or no leaves
          match the prefix.
        - ERROR if the request fails to deserialize.
        - OK otherwise, with ``entries`` filled in.
        """
        resp_proto = client_pb2.ClientStateListResponse
        request = client_pb2.ClientStateListRequest()
        # Explicit status instead of an error flag plus ``status or ERROR``,
        # which depended on no status enum value being falsy.
        status = resp_proto.OK

        try:
            request.ParseFromString(message.content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as e:
            status = resp_proto.NORESOURCE
            LOGGER.info(e)
        except DecodeError:
            status = resp_proto.ERROR
            LOGGER.info(
                "Expected protobuf of class %s failed to "
                "deserialize", request)

        if status != resp_proto.OK:
            response = resp_proto(status=status)
        else:
            leaves = self._tree.leaves(request.prefix)
            if not leaves:
                response = resp_proto(status=resp_proto.NORESOURCE)
            else:
                entries = [Entry(address=a, data=v) for a, v in leaves.items()]
                response = resp_proto(status=status, entries=entries)

        responder.send(
            validator_pb2.Message(
                sender=message.sender,
                message_type=validator_pb2.Message.CLIENT_STATE_LIST_RESPONSE,
                correlation_id=message.correlation_id,
                content=response.SerializeToString()))
class StateGetRequestHandler(Handler):
    """Handler that answers ClientStateGetRequest messages from the
    Merkle tree."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, identity, message_content):
        """Return a HandlerResult carrying the ClientStateGetResponse."""
        resp_proto = client_pb2.ClientStateGetResponse
        request = client_pb2.ClientStateGetRequest()

        status = self._load_request(request, message_content)
        value = None
        if status == resp_proto.OK:
            status, value = self._lookup(request.address)

        response = resp_proto(status=status)
        if status == resp_proto.OK:
            response.value = value

        return HandlerResult(
            status=HandlerStatus.RETURN,
            message_out=response,
            message_type=validator_pb2.Message.CLIENT_STATE_GET_RESPONSE)

    def _load_request(self, request, message_content):
        """Parse ``message_content`` into ``request`` and move the tree to
        its merkle root.  Returns the resulting response status."""
        resp_proto = client_pb2.ClientStateGetResponse
        try:
            request.ParseFromString(message_content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as e:
            LOGGER.debug(e)
            return resp_proto.NORESOURCE
        except DecodeError:
            LOGGER.info("Expected protobuf of class %s failed to "
                        "deserialize", request)
            return resp_proto.ERROR
        return resp_proto.OK

    def _lookup(self, address):
        """Read ``address`` from the tree; return ``(status, value)``."""
        resp_proto = client_pb2.ClientStateGetResponse
        try:
            return resp_proto.OK, self._tree.get(address)
        except KeyError:
            LOGGER.debug("No entry at state address %s", address)
            return resp_proto.NORESOURCE, None
        except ValueError:
            LOGGER.debug("Node at state address %s is a nonleaf", address)
            return resp_proto.NONLEAF, None
class StateListRequestHandler(Handler):
    """Handler that answers ClientStateListRequest messages with the
    Merkle-tree leaves under a given root and address prefix."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, identity, message_content):
        """Build the ClientStateListResponse for one serialized request.

        Response status:
        - NORESOURCE if setting the requested merkle root raises KeyError
          or no leaves match the prefix.
        - ERROR if the request fails to deserialize.
        - OK otherwise, with ``entries`` filled in.
        """
        request = client_pb2.ClientStateListRequest()
        resp_proto = client_pb2.ClientStateListResponse
        status = resp_proto.OK

        try:
            request.ParseFromString(message_content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as e:
            status = resp_proto.NORESOURCE
            LOGGER.debug(e)
        except DecodeError:
            status = resp_proto.ERROR
            LOGGER.info(
                "Expected protobuf of class %s failed to "
                "deserialize", request)

        if status != resp_proto.OK:
            response = resp_proto(status=status)
        else:
            leaves = self._tree.leaves(request.prefix)
            if not leaves:
                response = resp_proto(status=resp_proto.NORESOURCE)
            else:
                entries = [Entry(address=a, data=v) for a, v in leaves.items()]
                response = resp_proto(status=status, entries=entries)

        return HandlerResult(
            status=HandlerStatus.RETURN,
            message_out=response,
            message_type=validator_pb2.Message.CLIENT_STATE_LIST_RESPONSE)
class StateListRequestHandler(Handler):
    """Handler that lists Merkle-tree leaves for ClientStateListRequest
    messages."""

    def __init__(self, database):
        self._tree = MerkleDatabase(database)

    def handle(self, identity, message_content):
        """Return a HandlerResult carrying the ClientStateListResponse."""
        proto = client_pb2.ClientStateListResponse
        request = client_pb2.ClientStateListRequest()

        outcome = self._prepare(request, message_content)
        if outcome == proto.OK:
            reply = self._list_entries(request.prefix)
        else:
            reply = proto(status=outcome)

        return HandlerResult(
            status=HandlerStatus.RETURN,
            message_out=reply,
            message_type=validator_pb2.Message.CLIENT_STATE_LIST_RESPONSE)

    def _prepare(self, request, message_content):
        """Parse the request and move the tree to its merkle root.
        Returns the resulting response status."""
        proto = client_pb2.ClientStateListResponse
        try:
            request.ParseFromString(message_content)
            self._tree.set_merkle_root(request.merkle_root)
        except KeyError as err:
            LOGGER.debug(err)
            return proto.NORESOURCE
        except DecodeError:
            LOGGER.info("Expected protobuf of class %s failed to "
                        "deserialize", request)
            return proto.ERROR
        return proto.OK

    def _list_entries(self, prefix):
        """Build the response for the leaves under ``prefix``: NORESOURCE
        when there are none, otherwise OK with one Entry per leaf."""
        proto = client_pb2.ClientStateListResponse
        leaves = self._tree.leaves(prefix)
        if leaves:
            return proto(
                status=proto.OK,
                entries=[Entry(address=a, data=v) for a, v in leaves.items()])
        return proto(status=proto.NORESOURCE)
class TestSawtoothMerkleTrie(unittest.TestCase):
    """Tests MerkleDatabase root advancement, delete and update on an
    LMDBNoLockDatabase stored in a per-test temporary directory."""

    def setUp(self):
        # Function-scope imports: the module import block is not in view.
        import os
        import tempfile

        # A private temp dir instead of the machine-specific
        # "/home/vagrant/merkle.lmdb" path, so the test runs anywhere and
        # concurrent runs do not collide.
        self.dir = tempfile.mkdtemp()
        self.lmdb = lmdb_nolock_database.LMDBNoLockDatabase(
            os.path.join(self.dir, 'merkle.lmdb'), 'n')
        self.trie = MerkleDatabase(self.lmdb)

    def tearDown(self):
        import shutil  # function-scope: the module import block is not in view

        try:
            self.trie.close()
        finally:
            shutil.rmtree(self.dir, ignore_errors=True)

    def test_merkle_trie_root_advance(self):
        value = {"name": "foo", "value": 1}
        address = MerkleDatabase.hash("foo")

        orig_root = self.trie.get_merkle_root()
        new_root = self.trie.set(address, value)

        # set() hands back a new root without moving the current one.
        self.assertEqual(self.trie.get_merkle_root(), orig_root)
        with self.assertRaises(KeyError):
            self.trie.get(address)

        self.trie.set_merkle_root(new_root)
        self.assertEqual(self.trie.get(address), value)

    def test_merkle_trie_delete(self):
        value = {"name": "bar", "value": 1}
        address = MerkleDatabase.hash("bar")

        new_root = self.trie.set(address, value)
        self.trie.set_merkle_root(new_root)
        self.assertEqual(self.trie.get(address), value)

        del_root = self.trie.delete(address)
        self.trie.set_merkle_root(del_root)
        with self.assertRaises(KeyError):
            self.trie.get(address)

    def test_merkle_trie_update(self):
        value = ''.join(random.choice(string.ascii_lowercase)
                        for _ in range(512))
        keys = []
        for _ in range(1000):
            key = ''.join(random.choice(string.ascii_lowercase)
                          for _ in range(10))
            keys.append(key)
            new_root = self.trie.set(MerkleDatabase.hash(key), {key: value})
            self.trie.set_merkle_root(new_root)

        set_items = {
            MerkleDatabase.hash(key): {key: 5.0}
            for key in random.sample(keys, 50)
        }
        update_root = self.trie.update(set_items)
        self.trie.set_merkle_root(update_root)

        for address, expected in set_items.items():
            self.assertEqual(self.trie.get(address), expected)
class TestSawtoothMerkleTrie(unittest.TestCase):
    """Tests MerkleDatabase set/get/delete/update and leaf iteration on a
    NativeLmdbDatabase stored in a per-test temporary directory."""

    def setUp(self):
        self.dir = tempfile.mkdtemp()
        self.file = os.path.join(self.dir, 'merkle.lmdb')

        self.lmdb = NativeLmdbDatabase(
            self.file,
            indexes=MerkleDatabase.create_index_configuration(),
            _size=120 * 1024 * 1024)

        self.trie = MerkleDatabase(self.lmdb)

    def tearDown(self):
        try:
            self.trie.close()
        finally:
            # Remove the temp dir even if closing the trie fails.
            shutil.rmtree(self.dir)

    def test_merkle_trie_root_advance(self):
        value = {'name': 'foo', 'value': 1}

        orig_root = self.get_merkle_root()
        new_root = self.set('foo', value)

        self.assert_root(orig_root)
        self.assert_no_key('foo')

        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('foo', value)

    def test_merkle_trie_delete(self):
        value = {'name': 'bar', 'value': 1}

        new_root = self.set('bar', value)
        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        # deleting an invalid key should raise an error
        with self.assertRaises(KeyError):
            self.delete('barf')

        del_root = self.delete('bar')

        # del_root hasn't been set yet, so address should still have value
        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        self.set_merkle_root(del_root)

        self.assert_root(del_root)
        self.assert_no_key('bar')

    def test_merkle_trie_update(self):
        init_root = self.get_merkle_root()

        values = {}
        key_hashes = {
            key: _hash(key)
            for key in (_random_string(10) for _ in range(1000))
        }

        for key, hashed in key_hashes.items():
            value = {key: _random_string(512)}
            new_root = self.set(hashed, value, ishash=True)
            values[hashed] = value
            self.set_merkle_root(new_root)

        self.assert_not_root(init_root)

        for address, value in values.items():
            self.assert_value_at_address(address, value, ishash=True)

        # random.sample() requires a sequence: passing a dict view was
        # deprecated in Python 3.9 and raises TypeError from 3.11 on.
        set_items = {
            hashed: {key: 5.0}
            for key, hashed in random.sample(list(key_hashes.items()), 50)
        }
        values.update(set_items)
        delete_items = set(random.sample(list(key_hashes.values()), 50))

        # make sure there are no sets and deletes of the same key
        delete_items = delete_items - set_items.keys()
        for addr in delete_items:
            del values[addr]

        virtual_root = self.update(set_items, delete_items, virtual=True)

        # virtual root shouldn't match actual contents of tree
        with self.assertRaises(KeyError):
            self.set_merkle_root(virtual_root)

        actual_root = self.update(set_items, delete_items, virtual=False)

        # the virtual root should be the same as the actual root
        self.assertEqual(virtual_root, actual_root)

        # neither should be the root yet
        self.assert_not_root(virtual_root, actual_root)

        self.set_merkle_root(actual_root)
        self.assert_root(actual_root)

        for address, value in values.items():
            self.assert_value_at_address(address, value, ishash=True)

        for address in delete_items:
            with self.assertRaises(KeyError):
                self.get(address, ishash=True)

    def test_merkle_trie_leaf_iteration(self):
        new_root = self.update(
            {
                "010101": {"my_data": 1},
                "010202": {"my_data": 2},
                "010303": {"my_data": 3},
            },
            [],
            virtual=False)

        # iterate over the empty trie
        iterator = iter(self.trie)
        with self.assertRaises(StopIteration):
            next(iterator)

        self.set_merkle_root(new_root)

        # Test complete trie iteration
        self.assertEqual(
            [("010101", {"my_data": 1}),
             ("010202", {"my_data": 2}),
             ("010303", {"my_data": 3})],
            list(iter(self.trie)))

        # Test prefixed iteration
        self.assertEqual(
            [("010202", {"my_data": 2})],
            list(self.trie.leaves('0102')))

    # assertions

    def assert_value_at_address(self, address, value, ishash=False):
        self.assertEqual(self.get(address, ishash), value, 'Wrong value')

    def assert_no_key(self, key):
        with self.assertRaises(KeyError):
            self.get(key)

    def assert_root(self, expected):
        self.assertEqual(expected, self.get_merkle_root(), 'Wrong root')

    def assert_not_root(self, *not_roots):
        root = self.get_merkle_root()
        for not_root in not_roots:
            self.assertNotEqual(root, not_root, 'Wrong root')

    # trie accessors

    # For convenience, assume keys are not hashed
    # unless otherwise indicated.

    def set(self, key, val, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.set(key_, val)

    def get(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.get(key_)

    def delete(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.delete(key_)

    def set_merkle_root(self, root):
        self.trie.set_merkle_root(root)

    def get_merkle_root(self):
        return self.trie.get_merkle_root()

    def update(self, set_items, delete_items=None, virtual=True):
        return self.trie.update(set_items, delete_items, virtual=virtual)
class TestSawtoothMerkleTrie(unittest.TestCase):
    """Tests MerkleDatabase set/get/delete/update on an LMDBNoLockDatabase
    stored in a per-test temporary directory."""

    def setUp(self):
        self.dir = tempfile.mkdtemp()
        self.file = os.path.join(self.dir, 'merkle.lmdb')

        self.lmdb = lmdb_nolock_database.LMDBNoLockDatabase(self.file, 'n')

        self.trie = MerkleDatabase(self.lmdb)

    def tearDown(self):
        import shutil  # function-scope: the module import block is not in view

        try:
            self.trie.close()
        finally:
            # The temp dir from setUp used to be left behind on every test.
            shutil.rmtree(self.dir, ignore_errors=True)

    def test_merkle_trie_root_advance(self):
        value = {'name': 'foo', 'value': 1}

        orig_root = self.get_merkle_root()
        new_root = self.set('foo', value)

        self.assert_root(orig_root)
        self.assert_no_key('foo')

        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('foo', value)

    def test_merkle_trie_delete(self):
        value = {'name': 'bar', 'value': 1}

        new_root = self.set('bar', value)
        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        # deleting an invalid key should raise an error
        with self.assertRaises(KeyError):
            self.delete('barf')

        del_root = self.delete('bar')

        # del_root hasn't been set yet, so address should still have value
        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        self.set_merkle_root(del_root)

        self.assert_root(del_root)
        self.assert_no_key('bar')

    def test_merkle_trie_update(self):
        init_root = self.get_merkle_root()

        values = {}
        key_hashes = {
            key: _hash(key)
            for key in (_random_string(10) for _ in range(1000))
        }

        for key, hashed in key_hashes.items():
            value = {key: _random_string(512)}
            new_root = self.set(hashed, value, ishash=True)
            values[hashed] = value
            self.set_merkle_root(new_root)

        self.assert_not_root(init_root)

        for address, value in values.items():
            self.assert_value_at_address(address, value, ishash=True)

        # random.sample() requires a sequence: passing a dict view was
        # deprecated in Python 3.9 and raises TypeError from 3.11 on.
        set_items = {
            hashed: {key: 5.0}
            for key, hashed in random.sample(list(key_hashes.items()), 50)
        }
        values.update(set_items)
        delete_items = set(random.sample(list(key_hashes.values()), 50))

        # make sure there are no sets and deletes of the same key
        delete_items = delete_items - set_items.keys()
        for addr in delete_items:
            del values[addr]

        virtual_root = self.update(set_items, delete_items, virtual=True)

        # virtual root shouldn't match actual contents of tree
        with self.assertRaises(KeyError):
            self.set_merkle_root(virtual_root)

        actual_root = self.update(set_items, delete_items, virtual=False)

        # the virtual root should be the same as the actual root
        self.assertEqual(virtual_root, actual_root)

        # neither should be the root yet
        self.assert_not_root(virtual_root, actual_root)

        self.set_merkle_root(actual_root)
        self.assert_root(actual_root)

        for address, value in values.items():
            self.assert_value_at_address(address, value, ishash=True)

        for address in delete_items:
            with self.assertRaises(KeyError):
                self.get(address, ishash=True)

    # assertions

    def assert_value_at_address(self, address, value, ishash=False):
        self.assertEqual(self.get(address, ishash), value, 'Wrong value')

    def assert_no_key(self, key):
        with self.assertRaises(KeyError):
            self.get(key)

    def assert_root(self, expected):
        self.assertEqual(expected, self.get_merkle_root(), 'Wrong root')

    def assert_not_root(self, *not_roots):
        root = self.get_merkle_root()
        for not_root in not_roots:
            self.assertNotEqual(root, not_root, 'Wrong root')

    # trie accessors

    # For convenience, assume keys are not hashed
    # unless otherwise indicated.

    def set(self, key, val, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.set(key_, val)

    def get(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.get(key_)

    def delete(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.delete(key_)

    def set_merkle_root(self, root):
        self.trie.set_merkle_root(root)

    def get_merkle_root(self):
        return self.trie.get_merkle_root()

    def update(self, set_items, delete_items=None, virtual=True):
        return self.trie.update(set_items, delete_items, virtual=virtual)
class TestIdentityView(unittest.TestCase):
    """Tests IdentityView role and policy lookups against a MerkleDatabase
    backed by a NativeLmdbDatabase in a per-test temporary directory."""

    def __init__(self, test_name):
        super().__init__(test_name)
        self._temp_dir = None

    def setUp(self):
        self._temp_dir = tempfile.mkdtemp()
        # NOTE(review): pass the MerkleDatabase index configuration, as the
        # other NativeLmdbDatabase-backed MerkleDatabase setups do.  Without
        # it the tree's index tables are presumably not created.
        self._database = NativeLmdbDatabase(
            os.path.join(self._temp_dir, 'test_identity_view.lmdb'),
            indexes=MerkleDatabase.create_index_configuration(),
            _size=10 * 1024 * 1024)
        self._tree = MerkleDatabase(self._database)

    def tearDown(self):
        shutil.rmtree(self._temp_dir)

    def test_identityview_roles(self):
        """Tests get_role and get_roles get the correct Roles and the
        IdentityViewFactory produces the correct view of the database.

        Notes:
            1. Create an empty MerkleDatabase and update it with one
            serialized RoleList.
            2. Assert that get_role returns that named Role.
            3. Assert that get_role returns None for a name that doesn't
               correspond to a Role.
            4. Assert that all the Roles are returned by get_roles.
            5. Update the MerkleDatabase with another serialized RoleList with
               a different name.
            6. Repeat 2.
            7. Repeat 3.
            8. Repeat 4.

        """

        state_view_factory = StateViewFactory(self._database)
        identity_view_factory = identity_view.IdentityViewFactory(
            state_view_factory=state_view_factory)

        # 1.
        role_list = identity_pb2.RoleList()
        role1 = role_list.roles.add()
        role1_name = "sawtooth.test.example1"

        role1.name = role1_name
        role1.policy_name = "this_is_an_example"

        state_root1 = self._tree.update(
            set_items={
                _get_role_address(role1_name): role_list.SerializeToString()
            },
            virtual=False)

        # 2.
        identity_view1 = identity_view_factory.create_identity_view(
            state_hash=state_root1)
        self.assertEqual(
            identity_view1.get_role(role1_name),
            role1,
            "IdentityView().get_role returns the correct Role by name.")

        # 3.
        self.assertIsNone(
            identity_view1.get_role("Not-a-Role"),
            "IdentityView().get_role returns None if there is "
            "no Role with that name.")

        # 4.
        self.assertEqual(identity_view1.get_roles(),
                         [role1],
                         "IdentityView().get_roles returns all the roles in"
                         " State.")

        # 5.
        role_list2 = identity_pb2.RoleList()
        role2 = role_list2.roles.add()
        role2_name = "sawtooth.test.example2"

        role2.name = role2_name
        role2.policy_name = "this_is_another_example"

        self._tree.set_merkle_root(merkle_root=state_root1)

        state_root2 = self._tree.update(
            {
                _get_role_address(role2_name): role_list2.SerializeToString()
            },
            virtual=False)

        # 6.
        identity_view2 = identity_view_factory.create_identity_view(
            state_hash=state_root2)

        self.assertEqual(
            identity_view2.get_role(role2_name),
            role2,
            "IdentityView().get_role returns the correct Role by name.")

        # 7.
        self.assertIsNone(
            identity_view2.get_role("not-a-role2"),
            "IdentityView().get_role returns None for names that don't "
            "correspond to a Role.")

        # 8.
        self.assertEqual(
            identity_view2.get_roles(),
            [role1, role2],
            "IdentityView().get_roles() returns all the Roles in alphabetical "
            "order by name.")

    def test_identityview_policy(self):
        """Tests get_policy and get_policies get the correct Policies and
        the IdentityViewFactory produces the correct view of the database.

        Notes:
            1. Create an empty MerkleDatabase and update it with one
            serialized PolicyList.
            2. Assert that get_policy returns that named Policy.
            3. Assert that get_policy returns None for a name that doesn't
               correspond to a Policy.
            4. Assert that all the Policies are returned by get_policies.
            5. Update the MerkleDatabase with another serialized PolicyList
               with a different name.
            6. Repeat 2.
            7. Repeat 3.
            8. Repeat 4.

        """

        state_view_factory = StateViewFactory(self._database)
        identity_view_factory = identity_view.IdentityViewFactory(
            state_view_factory=state_view_factory)

        # 1.
        policy_list = identity_pb2.PolicyList()
        policy1 = policy_list.policies.add()
        policy1_name = "deny_all_keys"

        policy1.name = policy1_name

        state_root1 = self._tree.update(
            set_items={
                _get_policy_address(policy1_name):
                    policy_list.SerializeToString()
            },
            virtual=False)

        # 2.
        identity_view1 = identity_view_factory.create_identity_view(
            state_hash=state_root1)
        self.assertEqual(
            identity_view1.get_policy(policy1_name),
            policy1,
            "IdentityView().get_policy returns the correct Policy by name.")

        # 3.
        self.assertIsNone(
            identity_view1.get_policy("Not-a-Policy"),
            "IdentityView().get_policy returns None if "
            "there is no Policy with that name.")

        # 4.
        self.assertEqual(identity_view1.get_policies(),
                         [policy1],
                         "IdentityView().get_policies returns all the "
                         "policies in State.")

        # 5.
        policy_list2 = identity_pb2.PolicyList()
        policy2 = policy_list2.policies.add()
        policy2_name = "accept_all_keys"

        policy2.name = policy2_name

        self._tree.set_merkle_root(merkle_root=state_root1)

        state_root2 = self._tree.update(
            {
                _get_policy_address(policy2_name):
                    policy_list2.SerializeToString()
            },
            virtual=False)

        # 6.
        identity_view2 = identity_view_factory.create_identity_view(
            state_hash=state_root2)

        self.assertEqual(
            identity_view2.get_policy(policy2_name),
            policy2,
            "IdentityView().get_policy returns the correct Policy by name.")

        # 7.
        self.assertIsNone(
            identity_view2.get_policy("not-a-policy2"),
            "IdentityView().get_policy returns None for names that don't "
            "correspond to a Policy.")

        # 8.
        self.assertEqual(
            identity_view2.get_policies(),
            [policy2, policy1],
            "IdentityView().get_policies returns all the Policies in "
            "alphabetical order by name.")
class TestIdentityView(unittest.TestCase):
    """Checks IdentityView role/policy lookups against a MerkleDatabase on
    a NativeLmdbDatabase that lives in a per-test temporary directory."""

    def __init__(self, test_name):
        super().__init__(test_name)
        self._temp_dir = None

    def setUp(self):
        self._temp_dir = tempfile.mkdtemp()
        db_path = os.path.join(self._temp_dir, 'test_identity_view.lmdb')
        self._database = NativeLmdbDatabase(
            db_path,
            indexes=MerkleDatabase.create_index_configuration(),
            _size=10 * 1024 * 1024)
        self._tree = MerkleDatabase(self._database)

    def tearDown(self):
        shutil.rmtree(self._temp_dir)

    def _new_identity_view_factory(self):
        """Build an IdentityViewFactory that reads from this test's database."""
        return identity_view.IdentityViewFactory(
            state_view_factory=StateViewFactory(self._database))

    def _commit(self, address, payload):
        """Write ``payload`` at ``address`` (non-virtually) and return the
        resulting state root."""
        return self._tree.update({address: payload}, virtual=False)

    def test_identityview_roles(self):
        """Roles are retrievable individually and as a sorted list.

        Steps:
            1. Store one serialized RoleList in an empty tree.
            2. get_role finds that Role by name.
            3. get_role yields None for an unknown name.
            4. get_roles yields every Role.
            5. Store a second RoleList under another name, on top of the
               first root.
            6-8. Repeat checks 2-4 against the new root.
        """
        factory = self._new_identity_view_factory()

        # 1.
        first_list = identity_pb2.RoleList()
        first_role = first_list.roles.add()
        first_name = "sawtooth.test.example1"
        first_role.name = first_name
        first_role.policy_name = "this_is_an_example"
        first_root = self._commit(
            _get_role_address(first_name), first_list.SerializeToString())

        # 2. - 4.
        first_view = factory.create_identity_view(state_hash=first_root)
        self.assertEqual(
            first_view.get_role(first_name),
            first_role,
            "IdentityView().get_role returns the correct Role by name.")
        self.assertIsNone(
            first_view.get_role("Not-a-Role"),
            "IdentityView().get_role returns None if there is "
            "no Role with that name.")
        self.assertEqual(
            first_view.get_roles(),
            [first_role],
            "IdentityView().get_roles returns all the roles in"
            " State.")

        # 5.
        second_list = identity_pb2.RoleList()
        second_role = second_list.roles.add()
        second_name = "sawtooth.test.example2"
        second_role.name = second_name
        second_role.policy_name = "this_is_another_example"
        self._tree.set_merkle_root(merkle_root=first_root)
        second_root = self._commit(
            _get_role_address(second_name), second_list.SerializeToString())

        # 6. - 8.
        second_view = factory.create_identity_view(state_hash=second_root)
        self.assertEqual(
            second_view.get_role(second_name),
            second_role,
            "IdentityView().get_role returns the correct Role by name.")
        self.assertIsNone(
            second_view.get_role("not-a-role2"),
            "IdentityView().get_role returns None for names that don't "
            "correspond to a Role.")
        self.assertEqual(
            second_view.get_roles(),
            [first_role, second_role],
            "IdentityView().get_roles() returns all the Roles in alphabetical "
            "order by name.")

    def test_identityview_policy(self):
        """Policies are retrievable individually and as a sorted list.

        Steps:
            1. Store one serialized PolicyList in an empty tree.
            2. get_policy finds that Policy by name.
            3. get_policy yields None for an unknown name.
            4. get_policies yields every Policy.
            5. Store a second PolicyList under another name, on top of the
               first root.
            6-8. Repeat checks 2-4 against the new root.
        """
        factory = self._new_identity_view_factory()

        # 1.
        first_list = identity_pb2.PolicyList()
        first_policy = first_list.policies.add()
        first_name = "deny_all_keys"
        first_policy.name = first_name
        first_root = self._commit(
            _get_policy_address(first_name), first_list.SerializeToString())

        # 2. - 4.
        first_view = factory.create_identity_view(state_hash=first_root)
        self.assertEqual(
            first_view.get_policy(first_name),
            first_policy,
            "IdentityView().get_policy returns the correct Policy by name.")
        self.assertIsNone(
            first_view.get_policy("Not-a-Policy"),
            "IdentityView().get_policy returns None if "
            "there is no Policy with that name.")
        self.assertEqual(
            first_view.get_policies(),
            [first_policy],
            "IdentityView().get_policies returns all the "
            "policies in State.")

        # 5.
        second_list = identity_pb2.PolicyList()
        second_policy = second_list.policies.add()
        second_name = "accept_all_keys"
        second_policy.name = second_name
        self._tree.set_merkle_root(merkle_root=first_root)
        second_root = self._commit(
            _get_policy_address(second_name), second_list.SerializeToString())

        # 6. - 8.
        second_view = factory.create_identity_view(state_hash=second_root)
        self.assertEqual(
            second_view.get_policy(second_name),
            second_policy,
            "IdentityView().get_policy returns the correct Policy by name.")
        self.assertIsNone(
            second_view.get_policy("not-a-policy2"),
            "IdentityView().get_policy returns None for names that don't "
            "correspond to a Policy.")
        self.assertEqual(
            second_view.get_policies(),
            [second_policy, first_policy],
            "IdentityView().get_policies returns all the Policies in "
            "alphabetical order by name.")
def test_complex_basecontext_squash(self):
    """Tests complex context basing and squashing.

    Layout of the contexts (i= inputs, o= outputs, xx=N means address xx
    is set to bytes(N)):

                                              i=qq,dd dd=0
                                              o=dd,pp pp=1
                                 i=cc,aa  +->context_3_2a_1-+
                                 o=dd,ll  |                 |
               i=aa,ab       +->context_2a+  i=aa    aa=0   |
               o=cc,ab       |   dd=10    |  o=aa,ll ll=1   |
        sh0->context_1-->sh1-+   ll=11    +->context_3_2a_2-+->sh2
               cc=0          |   i=cc,aa  +->context_3_2b_1-+
               ab=1          |   o=nn,mm  |  i=nn,ab mm=0   |
                             +->context_2b+  o=mm,ba ba=1   |
                                 nn=0     |                 |
                                 mm=1     +->context_3_2b_2-+
                                              i=nn,oo ab=0
                                              o=ab,oo oo=1

    Notes:
        Test:
            1. Create a context off of the first state hash, set
               addresses in it, and squash that context, getting a new
               merkle root (sh1).
            2. Create 2 contexts off of sh1 and, for each of these
               contexts, set addresses to values where the outputs for
               each are disjoint.
            3. For each of these 2 contexts create 2 more contexts, each
               having one of the contexts in #2 as the base context, and
               set addresses to values.
            4. Squash the 4 contexts from #3 and assert the state hash is
               equal to a manually computed state hash.
    """
    squash = self.context_manager.get_squash_handler()

    def _indexed_values(addresses, start=0):
        # Build the address_value_list shape passed to context_manager.set:
        # one single-entry dict per address, the i-th address mapped to
        # bytes(start + i). Note that bytes(n) is n zero bytes (not the
        # text of n); the manual computation in step 4 uses the same
        # encoding, so the comparison stays consistent.
        return [{address: bytes(start + index)}
                for index, address in enumerate(addresses)]

    # 1) Set cc and ab on top of the first state hash and persist the
    # squashed result as sh1.
    inputs_1 = [self._create_address('aa'), self._create_address('ab')]
    outputs_1 = [self._create_address('cc'), self._create_address('ab')]
    context_1 = self.context_manager.create_context(
        state_hash=self.first_state_hash,
        base_contexts=[],
        inputs=inputs_1,
        outputs=outputs_1)
    self.context_manager.set(
        context_id=context_1,
        address_value_list=_indexed_values(outputs_1))
    sh1 = squash(
        state_root=self.first_state_hash,
        context_ids=[context_1],
        persist=True,
        clean_up=True)

    # 2) Two sibling contexts off sh1 (as in the diagram) whose outputs
    # are disjoint.
    inputs_2a = [self._create_address('cc'), self._create_address('aa')]
    outputs_2a = [self._create_address('dd'), self._create_address('ll')]
    context_2a = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[],
        inputs=inputs_2a,
        outputs=outputs_2a)

    inputs_2b = [self._create_address('cc'), self._create_address('aa')]
    outputs_2b = [self._create_address('nn'), self._create_address('mm')]
    context_2b = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[],
        inputs=inputs_2b,
        outputs=outputs_2b)

    self.context_manager.set(
        context_id=context_2a,
        address_value_list=_indexed_values(outputs_2a, start=10))
    self.context_manager.set(
        context_id=context_2b,
        address_value_list=_indexed_values(outputs_2b))

    # 3) Two child contexts based on each of the contexts from step 2.
    inputs_3_2a_1 = [self._create_address('qq'), self._create_address('dd')]
    outputs_3_2a_1 = [self._create_address('dd'),
                      self._create_address('pp')]
    context_3_2a_1 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2a],
        inputs=inputs_3_2a_1,
        outputs=outputs_3_2a_1)

    inputs_3_2a_2 = [self._create_address('aa')]
    outputs_3_2a_2 = [self._create_address('aa'),
                      self._create_address('ll')]
    context_3_2a_2 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2a],
        inputs=inputs_3_2a_2,
        outputs=outputs_3_2a_2)

    inputs_3_2b_1 = [self._create_address('nn'), self._create_address('ab')]
    outputs_3_2b_1 = [self._create_address('mm'),
                      self._create_address('ba')]
    context_3_2b_1 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2b],
        inputs=inputs_3_2b_1,
        outputs=outputs_3_2b_1)

    inputs_3_2b_2 = [self._create_address('nn'), self._create_address('oo')]
    outputs_3_2b_2 = [self._create_address('ab'),
                      self._create_address('oo')]
    context_3_2b_2 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2b],
        inputs=inputs_3_2b_2,
        outputs=outputs_3_2b_2)

    for context_id, outputs in ((context_3_2a_1, outputs_3_2a_1),
                                (context_3_2a_2, outputs_3_2a_2),
                                (context_3_2b_1, outputs_3_2b_1),
                                (context_3_2b_2, outputs_3_2b_2)):
        self.context_manager.set(
            context_id=context_id,
            address_value_list=_indexed_values(outputs))

    # 4) Squash the four leaf contexts (not persisted) and compare with a
    # hand-built tree.
    sh2 = squash(
        state_root=sh1,
        context_ids=[context_3_2a_1, context_3_2a_2,
                     context_3_2b_1, context_3_2b_2],
        persist=False,
        clean_up=True)

    tree = MerkleDatabase(self.database_results)
    state_hash_from_1 = tree.update(
        set_items={address: bytes(index)
                   for index, address in enumerate(outputs_1)},
        virtual=False)
    # assertEqual: the assertEquals alias is gone in Python 3.12+.
    self.assertEqual(state_hash_from_1, sh1,
                     "The manually calculated state hash from the first "
                     "context and the one calculated by squashing that "
                     "state hash should be the same")
    tree.set_merkle_root(state_hash_from_1)
    # Values written by a child context replace those of its base
    # (dd, ll, mm) and of sh1 (ab); cc keeps its sh1 value.
    test_sh2 = tree.update(set_items={
        self._create_address('aa'): bytes(0),
        self._create_address('ab'): bytes(0),
        self._create_address('ba'): bytes(1),
        self._create_address('dd'): bytes(0),
        self._create_address('ll'): bytes(1),
        self._create_address('mm'): bytes(0),
        self._create_address('oo'): bytes(1),
        self._create_address('pp'): bytes(1),
        self._create_address('nn'): bytes(0),
        self._create_address('cc'): bytes(0)})

    self.assertEqual(sh2, test_sh2, "Manually calculated and context "
                                    "manager calculated merkle hashes "
                                    "are the same")
class TestSawtoothMerkleTrie(unittest.TestCase):
    """Tests MerkleDatabase backed by an LMDBNoLockDatabase file.

    Each test gets a fresh LMDB file inside its own temporary directory,
    which tearDown closes and removes.
    """

    def setUp(self):
        self.dir = tempfile.mkdtemp()
        self.file = os.path.join(self.dir, 'merkle.lmdb')

        self.lmdb = lmdb_nolock_database.LMDBNoLockDatabase(
            self.file,
            'n')

        self.trie = MerkleDatabase(self.lmdb)

    def tearDown(self):
        # Imported locally so this class does not rely on the module's
        # import block providing shutil.
        import shutil

        try:
            self.trie.close()
        finally:
            # tempfile.mkdtemp() leaves removal to the caller; without this
            # every test leaks a directory holding an LMDB file.
            shutil.rmtree(self.dir)

    def test_merkle_trie_root_advance(self):
        """Setting a key returns a new root but does not move the current
        root; the value becomes visible only once that root is set."""
        value = {'name': 'foo', 'value': 1}

        orig_root = self.get_merkle_root()
        new_root = self.set('foo', value)

        self.assert_root(orig_root)
        self.assert_no_key('foo')

        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('foo', value)

    def test_merkle_trie_delete(self):
        """Deleting a missing key raises KeyError; deleting a present key
        returns a new root under which the key is absent."""
        value = {'name': 'bar', 'value': 1}

        new_root = self.set('bar', value)
        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        # deleting an invalid key should raise an error
        with self.assertRaises(KeyError):
            self.delete('barf')

        del_root = self.delete('bar')

        # del_root hasn't been set yet, so address should still have value
        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        self.set_merkle_root(del_root)

        self.assert_root(del_root)
        self.assert_no_key('bar')

    def test_merkle_trie_update(self):
        """Batch update over many random keys: a virtual update computes
        the same root as a persisted one without making it settable, and
        the persisted root reflects every set and delete."""
        init_root = self.get_merkle_root()

        values = {}
        key_hashes = {
            key: _hash(key)
            for key in (_random_string(10) for _ in range(1000))
        }

        for key, hashed in key_hashes.items():
            value = {key: _random_string(512)}
            new_root = self.set(hashed, value, ishash=True)
            values[hashed] = value
            self.set_merkle_root(new_root)

        self.assert_not_root(init_root)

        for address, value in values.items():
            self.assert_value_at_address(
                address, value, ishash=True)

        # random.sample() needs a sequence; a dict view raises TypeError
        # on Python 3.11+, so materialize it first.
        set_items = {
            hashed: {
                key: 5.0
            }
            for key, hashed in random.sample(list(key_hashes.items()), 50)
        }
        values.update(set_items)
        delete_items = set(random.sample(list(key_hashes.values()), 50))

        # make sure there are no sets and deletes of the same key
        delete_items = delete_items - set_items.keys()
        for addr in delete_items:
            del values[addr]

        virtual_root = self.update(set_items, delete_items, virtual=True)

        # virtual root shouldn't match actual contents of tree
        with self.assertRaises(KeyError):
            self.set_merkle_root(virtual_root)

        actual_root = self.update(set_items, delete_items, virtual=False)

        # the virtual root should be the same as the actual root
        self.assertEqual(virtual_root, actual_root)

        # neither should be the root yet
        self.assert_not_root(
            virtual_root,
            actual_root)

        self.set_merkle_root(actual_root)
        self.assert_root(actual_root)

        for address, value in values.items():
            self.assert_value_at_address(
                address, value, ishash=True)

        for address in delete_items:
            with self.assertRaises(KeyError):
                self.get(address, ishash=True)

    # assertions
    def assert_value_at_address(self, address, value, ishash=False):
        self.assertEqual(
            self.get(address, ishash),
            value,
            'Wrong value')

    def assert_no_key(self, key):
        with self.assertRaises(KeyError):
            self.get(key)

    def assert_root(self, expected):
        self.assertEqual(
            expected,
            self.get_merkle_root(),
            'Wrong root')

    def assert_not_root(self, *not_roots):
        root = self.get_merkle_root()
        for not_root in not_roots:
            self.assertNotEqual(
                root,
                not_root,
                'Wrong root')

    # trie accessors

    # For convenience, assume keys are not hashed
    # unless otherwise indicated.

    def set(self, key, val, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.set(key_, val)

    def get(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.get(key_)

    def delete(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.delete(key_)

    def set_merkle_root(self, root):
        self.trie.set_merkle_root(root)

    def get_merkle_root(self):
        return self.trie.get_merkle_root()

    def update(self, set_items, delete_items=None, virtual=True):
        return self.trie.update(set_items, delete_items, virtual=virtual)
def test_complex_basecontext_squash(self):
    """Tests complex context basing and squashing.

    Layout of the contexts (i= inputs, o= outputs, xx=N means address xx
    is set to bytes(N)):

                                              i=qq,dd dd=0
                                              o=dd,pp pp=1
                                 i=cc,aa  +->context_3_2a_1-+
                                 o=dd,ll  |                 |
               i=aa,ab       +->context_2a+  i=aa    aa=0   |
               o=cc,ab       |   dd=10    |  o=aa,ll ll=1   |
        sh0->context_1-->sh1-+   ll=11    +->context_3_2a_2-+->sh2
               cc=0          |   i=cc,aa  +->context_3_2b_1-+
               ab=1          |   o=nn,mm  |  i=nn,ab mm=0   |
                             +->context_2b+  o=mm,ba ba=1   |
                                 nn=0     |                 |
                                 mm=1     +->context_3_2b_2-+
                                              i=nn,oo ab=0
                                              o=ab,oo oo=1

    Notes:
        Test:
            1. Create a context off of the first state hash, set
               addresses in it, and squash that context, getting a new
               merkle root (sh1).
            2. Create 2 contexts off of sh1 and, for each of these
               contexts, set addresses to values where the outputs for
               each are disjoint.
            3. For each of these 2 contexts create 2 more contexts, each
               having one of the contexts in #2 as the base context, and
               set addresses to values.
            4. Squash the 4 contexts from #3 and assert the state hash is
               equal to a manually computed state hash.
    """
    squash = self.context_manager.get_squash_handler()

    def _indexed_values(addresses, start=0):
        # Build the address_value_list shape passed to context_manager.set:
        # one single-entry dict per address, the i-th address mapped to
        # bytes(start + i). Note that bytes(n) is n zero bytes (not the
        # text of n); the manual computation in step 4 uses the same
        # encoding, so the comparison stays consistent.
        return [{address: bytes(start + index)}
                for index, address in enumerate(addresses)]

    # 1) Set cc and ab on top of the first state hash and persist the
    # squashed result as sh1.
    inputs_1 = [self._create_address('aa'), self._create_address('ab')]
    outputs_1 = [self._create_address('cc'), self._create_address('ab')]
    context_1 = self.context_manager.create_context(
        state_hash=self.first_state_hash,
        base_contexts=[],
        inputs=inputs_1,
        outputs=outputs_1)
    self.context_manager.set(
        context_id=context_1,
        address_value_list=_indexed_values(outputs_1))
    sh1 = squash(
        state_root=self.first_state_hash,
        context_ids=[context_1],
        persist=True,
        clean_up=True)

    # 2) Two sibling contexts off sh1 (as in the diagram) whose outputs
    # are disjoint.
    inputs_2a = [self._create_address('cc'), self._create_address('aa')]
    outputs_2a = [self._create_address('dd'), self._create_address('ll')]
    context_2a = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[],
        inputs=inputs_2a,
        outputs=outputs_2a)

    inputs_2b = [self._create_address('cc'), self._create_address('aa')]
    outputs_2b = [self._create_address('nn'), self._create_address('mm')]
    context_2b = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[],
        inputs=inputs_2b,
        outputs=outputs_2b)

    self.context_manager.set(
        context_id=context_2a,
        address_value_list=_indexed_values(outputs_2a, start=10))
    self.context_manager.set(
        context_id=context_2b,
        address_value_list=_indexed_values(outputs_2b))

    # 3) Two child contexts based on each of the contexts from step 2.
    inputs_3_2a_1 = [self._create_address('qq'), self._create_address('dd')]
    outputs_3_2a_1 = [self._create_address('dd'),
                      self._create_address('pp')]
    context_3_2a_1 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2a],
        inputs=inputs_3_2a_1,
        outputs=outputs_3_2a_1)

    inputs_3_2a_2 = [self._create_address('aa')]
    outputs_3_2a_2 = [self._create_address('aa'),
                      self._create_address('ll')]
    context_3_2a_2 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2a],
        inputs=inputs_3_2a_2,
        outputs=outputs_3_2a_2)

    inputs_3_2b_1 = [self._create_address('nn'), self._create_address('ab')]
    outputs_3_2b_1 = [self._create_address('mm'),
                      self._create_address('ba')]
    context_3_2b_1 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2b],
        inputs=inputs_3_2b_1,
        outputs=outputs_3_2b_1)

    inputs_3_2b_2 = [self._create_address('nn'), self._create_address('oo')]
    outputs_3_2b_2 = [self._create_address('ab'),
                      self._create_address('oo')]
    context_3_2b_2 = self.context_manager.create_context(
        state_hash=sh1,
        base_contexts=[context_2b],
        inputs=inputs_3_2b_2,
        outputs=outputs_3_2b_2)

    for context_id, outputs in ((context_3_2a_1, outputs_3_2a_1),
                                (context_3_2a_2, outputs_3_2a_2),
                                (context_3_2b_1, outputs_3_2b_1),
                                (context_3_2b_2, outputs_3_2b_2)):
        self.context_manager.set(
            context_id=context_id,
            address_value_list=_indexed_values(outputs))

    # 4) Squash the four leaf contexts (not persisted) and compare with a
    # hand-built tree.
    sh2 = squash(
        state_root=sh1,
        context_ids=[context_3_2a_1, context_3_2a_2,
                     context_3_2b_1, context_3_2b_2],
        persist=False,
        clean_up=True)

    tree = MerkleDatabase(self.database_results)
    state_hash_from_1 = tree.update(
        set_items={address: bytes(index)
                   for index, address in enumerate(outputs_1)},
        virtual=False)
    # assertEqual: the assertEquals alias is gone in Python 3.12+.
    self.assertEqual(
        state_hash_from_1, sh1,
        "The manually calculated state hash from the first "
        "context and the one calculated by squashing that "
        "state hash should be the same")
    tree.set_merkle_root(state_hash_from_1)
    # Values written by a child context replace those of its base
    # (dd, ll, mm) and of sh1 (ab); cc keeps its sh1 value.
    test_sh2 = tree.update(
        set_items={
            self._create_address('aa'): bytes(0),
            self._create_address('ab'): bytes(0),
            self._create_address('ba'): bytes(1),
            self._create_address('dd'): bytes(0),
            self._create_address('ll'): bytes(1),
            self._create_address('mm'): bytes(0),
            self._create_address('oo'): bytes(1),
            self._create_address('pp'): bytes(1),
            self._create_address('nn'): bytes(0),
            self._create_address('cc'): bytes(0)
        })

    self.assertEqual(
        sh2, test_sh2, "Manually calculated and context "
        "manager calculated merkle hashes "
        "are the same")
class TestSawtoothMerkleTrie(unittest.TestCase):
    """Tests MerkleDatabase backed by a NativeLmdbDatabase file.

    Each test gets a fresh LMDB file inside its own temporary directory,
    which tearDown closes and removes.
    """

    def setUp(self):
        self.dir = tempfile.mkdtemp()
        self.file = os.path.join(self.dir, 'merkle.lmdb')

        self.lmdb = NativeLmdbDatabase(
            self.file,
            indexes=MerkleDatabase.create_index_configuration(),
            _size=120 * 1024 * 1024)

        self.trie = MerkleDatabase(self.lmdb)

    def tearDown(self):
        try:
            self.trie.close()
        finally:
            # Remove the temp directory even if close() raises, so a
            # failing teardown does not also leak the LMDB files.
            shutil.rmtree(self.dir)

    def test_merkle_trie_root_advance(self):
        """Setting a key returns a new root but does not move the current
        root; the value becomes visible only once that root is set."""
        value = {'name': 'foo', 'value': 1}

        orig_root = self.get_merkle_root()
        new_root = self.set('foo', value)

        self.assert_root(orig_root)
        self.assert_no_key('foo')

        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('foo', value)

    def test_merkle_trie_delete(self):
        """Deleting a missing key raises KeyError; deleting a present key
        returns a new root under which the key is absent."""
        value = {'name': 'bar', 'value': 1}

        new_root = self.set('bar', value)
        self.set_merkle_root(new_root)

        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        # deleting an invalid key should raise an error
        with self.assertRaises(KeyError):
            self.delete('barf')

        del_root = self.delete('bar')

        # del_root hasn't been set yet, so address should still have value
        self.assert_root(new_root)
        self.assert_value_at_address('bar', value)

        self.set_merkle_root(del_root)

        self.assert_root(del_root)
        self.assert_no_key('bar')

    def test_merkle_trie_update(self):
        """Batch update over many random keys: a virtual update computes
        the same root as a persisted one without making it settable, and
        the persisted root reflects every set and delete."""
        init_root = self.get_merkle_root()

        values = {}
        key_hashes = {
            key: _hash(key)
            for key in (_random_string(10) for _ in range(1000))
        }

        for key, hashed in key_hashes.items():
            value = {key: _random_string(512)}
            new_root = self.set(hashed, value, ishash=True)
            values[hashed] = value
            self.set_merkle_root(new_root)

        self.assert_not_root(init_root)

        for address, value in values.items():
            self.assert_value_at_address(
                address, value, ishash=True)

        # random.sample() needs a sequence; a dict view raises TypeError
        # on Python 3.11+, so materialize it first.
        set_items = {
            hashed: {
                key: 5.0
            }
            for key, hashed in random.sample(list(key_hashes.items()), 50)
        }
        values.update(set_items)
        delete_items = set(random.sample(list(key_hashes.values()), 50))

        # make sure there are no sets and deletes of the same key
        delete_items = delete_items - set_items.keys()
        for addr in delete_items:
            del values[addr]

        virtual_root = self.update(set_items, delete_items, virtual=True)

        # virtual root shouldn't match actual contents of tree
        with self.assertRaises(KeyError):
            self.set_merkle_root(virtual_root)

        actual_root = self.update(set_items, delete_items, virtual=False)

        # the virtual root should be the same as the actual root
        self.assertEqual(virtual_root, actual_root)

        # neither should be the root yet
        self.assert_not_root(
            virtual_root,
            actual_root)

        self.set_merkle_root(actual_root)
        self.assert_root(actual_root)

        for address, value in values.items():
            self.assert_value_at_address(
                address, value, ishash=True)

        for address in delete_items:
            with self.assertRaises(KeyError):
                self.get(address, ishash=True)

    def test_merkle_trie_leaf_iteration(self):
        """Iterating the trie yields (address, value) pairs for the current
        root (asserted here in ascending address order), and
        leaves(prefix) yields only the pairs under that prefix."""
        new_root = self.update({
            "010101": {"my_data": 1},
            "010202": {"my_data": 2},
            "010303": {"my_data": 3}
        }, [], virtual=False)

        # iterate over the empty trie (new_root has not been set yet)
        iterator = iter(self.trie)
        with self.assertRaises(StopIteration):
            next(iterator)

        self.set_merkle_root(new_root)

        # Test complete trie iteration
        self.assertEqual(
            [("010101", {"my_data": 1}),
             ("010202", {"my_data": 2}),
             ("010303", {"my_data": 3})],
            list(self.trie))

        # Test prefixed iteration
        self.assertEqual([("010202", {"my_data": 2})],
                         list(self.trie.leaves('0102')))

    # assertions
    def assert_value_at_address(self, address, value, ishash=False):
        self.assertEqual(
            self.get(address, ishash),
            value,
            'Wrong value')

    def assert_no_key(self, key):
        with self.assertRaises(KeyError):
            self.get(key)

    def assert_root(self, expected):
        self.assertEqual(
            expected,
            self.get_merkle_root(),
            'Wrong root')

    def assert_not_root(self, *not_roots):
        root = self.get_merkle_root()
        for not_root in not_roots:
            self.assertNotEqual(
                root,
                not_root,
                'Wrong root')

    # trie accessors

    # For convenience, assume keys are not hashed
    # unless otherwise indicated.

    def set(self, key, val, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.set(key_, val)

    def get(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.get(key_)

    def delete(self, key, ishash=False):
        key_ = key if ishash else _hash(key)
        return self.trie.delete(key_)

    def set_merkle_root(self, root):
        self.trie.set_merkle_root(root)

    def get_merkle_root(self):
        return self.trie.get_merkle_root()

    def update(self, set_items, delete_items=None, virtual=True):
        return self.trie.update(set_items, delete_items, virtual=virtual)