def test_remove_from_index(self) -> None:
    """
    Removing a single UID drops that descriptor from the descriptor set and
    its (unshared) hash entry from the hash-to-UIDs key-value store, while
    every other descriptor and hash entry is left intact.
    """
    # One-dimensional descriptors whose single value equals their UID, so the
    # dummy hash functor's output is trivially predictable.
    elements = [DescriptorMemoryElement(uid) for uid in range(5)]
    for elem in elements:
        elem.set_vector(np.ones(1, float) * elem.uuid())

    index = LSHNearestNeighborIndex(
        DummyHashFunctor(), MemoryDescriptorSet(), MemoryKeyValueStore()
    )
    index.build_index(elements)

    # Drop exactly one UID.
    index.remove_from_index([3])

    survivors = (0, 1, 2, 4)

    # Confirm the concrete backend types (this also narrows them for static
    # typing) before inspecting their private `_table` contents.
    assert isinstance(index.descriptor_set, MemoryDescriptorSet)
    self.assertEqual(
        index.descriptor_set._table,
        {uid: elements[uid] for uid in survivors},
    )

    assert isinstance(index.hash2uuids_kvstore, MemoryKeyValueStore)
    self.assertEqual(
        index.hash2uuids_kvstore._table,
        {uid: {uid} for uid in survivors},
    )
def test_remove_from_index_shared_hashes_partial(self) -> None:
    """
    A hash should be removed from the hash index only once no surviving
    descriptor maps to it; hashes still referenced by remaining descriptors
    must be kept.
    """
    # Five 1-dim descriptors whose single vector value equals their UID.
    descriptors = [
        DescriptorMemoryElement(uid).set_vector([uid]) for uid in range(5)
    ]

    # Parity "hash": vectors with an even sum map to 0, odd sums map to 1.
    hash_func = DummyHashFunctor()
    hash_func.get_hash = mock.Mock(  # type: ignore
        side_effect=lambda vec: [vec.sum() % 2]
    )

    # Pre-populate both stores as if the parity hash had already indexed
    # every descriptor: even UIDs under hash 0, odd UIDs under hash 1.
    d_set = MemoryDescriptorSet()
    d_set._table = {uid: descriptors[uid] for uid in range(5)}
    hash2uid_kvs = MemoryKeyValueStore()
    hash2uid_kvs._table = {0: {0, 2, 4}, 1: {1, 3}}

    idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uid_kvs)
    idx.hash_index = mock.Mock(spec=HashIndex)

    # UIDs 1 and 3 are the only members of hash 1, so that hash is orphaned;
    # removing UID 2 still leaves hash 0 referenced by UIDs 0 and 4.
    idx.remove_from_index([1, 2, 3])

    # The hash index's removal method should receive a deque holding exactly
    # one hash-code vector: the orphaned [1].
    expected_removed_hashes = collections.deque([[1]])
    idx.hash_index.remove_from_index.assert_called_once_with(
        expected_removed_hashes
    )

    self.assertDictEqual(
        d_set._table, {0: descriptors[0], 4: descriptors[4]}
    )
    self.assertDictEqual(hash2uid_kvs._table, {0: {0, 4}})
def test_remove_from_index_shared_hashes(self) -> None:
    """
    Test that removing a descriptor (by UID) that shares a hash with other
    descriptors does not trigger removal of its hash.

    All descriptors are forced to hash to the same value (0). Removing a
    subset of them must shrink that hash's UID set in the key-value store,
    keep the hash key itself, and never call the hash index's
    ``remove_from_index``.
    """
    # Simulate descriptors all hashing to the same hash value: 0.
    # The ignore sits on the statement's first line so that it actually
    # suppresses the method-assignment typing error.
    hash_func = DummyHashFunctor()
    hash_func.get_hash = mock.Mock(  # type: ignore
        return_value=np.asarray([0], bool)
    )

    d_set = MemoryDescriptorSet()
    hash2uids_kvs = MemoryKeyValueStore()
    idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uids_kvs)

    # Descriptors are 1 dim, value == index. The UID is the sole
    # constructor argument; the previous ('t', uid) form passed an extra
    # type-string argument.
    descriptors = [DescriptorMemoryElement(uid) for uid in range(5)]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors:
        d.set_vector(np.ones(1, float) * d.uuid())
    idx.build_index(descriptors)

    # We expect the descriptor-set and kvs to look like the following now:
    self.assertDictEqual(
        d_set._table,
        {
            0: descriptors[0],
            1: descriptors[1],
            2: descriptors[2],
            3: descriptors[3],
            4: descriptors[4],
        },
    )
    self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 2, 3, 4}})

    # Mock out hash index as if we had an implementation so we can check
    # the call to its remove_from_index method.
    idx.hash_index = mock.Mock(spec=HashIndex)

    idx.remove_from_index([2, 4])

    # Only the uid 2 and 4 descriptors should be gone from the d-set. The
    # kvs should still have the 0 key, and its set value should only
    # contain uids 0, 1 and 3. `hash_index.remove_from_index` should not be
    # called because no hashes should be marked for removal.
    self.assertDictEqual(
        d_set._table,
        {
            0: descriptors[0],
            1: descriptors[1],
            3: descriptors[3],
        },
    )
    self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 3}})
    idx.hash_index.remove_from_index.assert_not_called()