def test_update_index_no_existing_index(self):
    # Test that calling update_index with no existing index acts like
    # building the index fresh.  This test is basically the same as
    # test_build_index_fresh_build but using update_index instead.
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(DummyHashFunctor(),
                                    descr_set, hash_kvs)

    descriptors = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors:
        d.set_vector(np.ones(1, float) * d.uuid())
    index.update_index(descriptors)

    # Make sure descriptors are now in attached index and in key-value-store
    self.assertEqual(descr_set.count(), 5)
    for d in descriptors:
        self.assertIn(d, descr_set)
    # Dummy hash function bins sum of descriptor vectors.
    self.assertEqual(hash_kvs.count(), 5)
    for i in range(5):
        self.assertSetEqual(hash_kvs.get(i), {i})
def test_build_index_fresh_build(self):
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(DummyHashFunctor(),
                                    descr_set, hash_kvs)

    descriptors = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for i, d in enumerate(descriptors):
        d.set_vector(np.ones(1, float) * i)
    index.build_index(descriptors)

    # Make sure descriptors are now in attached index and in
    # key-value-store.
    self.assertEqual(descr_set.count(), 5)
    for d in descriptors:
        self.assertIn(d, descr_set)
    # Dummy hash function bins sum of descriptor vectors.
    self.assertEqual(hash_kvs.count(), 5)
    for i in range(5):
        self.assertSetEqual(hash_kvs.get(i), {i})
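# The LSH tests in this section lean on a ``DummyHashFunctor`` whose hash of
# a length-1 vector [i] ends up keying bin ``i`` in the key-value store
# ("bins sum of descriptor vectors").  That helper is not shown here; the
# class below is only a minimal sketch of the assumed behavior (name,
# interface details, and bit ordering are guesses, not the test suite's
# actual implementation).
import numpy as np


class SketchDummyHashFunctor(object):
    """Hypothetical stand-in: hash a descriptor vector to the bits of its sum."""

    def get_hash(self, vec):
        # Sum the vector and expose the integer as a little-endian boolean
        # bit vector; the LSH index is assumed to convert such a bit vector
        # back into the integer key seen in ``hash_kvs`` above.
        v = int(np.sum(vec))
        n_bits = max(v.bit_length(), 1)
        return np.asarray([(v >> b) & 1 for b in range(n_bits)], bool)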
def test_count_empty_hash2uid(self):
    """
    Test that an empty hash-to-uid mapping results in a 0 return regardless
    of descriptor-set state.
    """
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    self.assertEqual(descr_set.count(), 0)
    self.assertEqual(hash_kvs.count(), 0)

    lsh = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set, hash_kvs)
    self.assertEqual(lsh.count(), 0)

    # Additions to the descriptor-set should not impact LSH index "size"
    lsh.descriptor_set.add_descriptor(DescriptorMemoryElement('t', 0))
    self.assertEqual(lsh.descriptor_set.count(), 1)
    self.assertEqual(lsh.hash2uuids_kvstore.count(), 0)
    self.assertEqual(lsh.count(), 0)

    lsh.descriptor_set.add_descriptor(DescriptorMemoryElement('t', 1))
    self.assertEqual(lsh.descriptor_set.count(), 2)
    self.assertEqual(lsh.hash2uuids_kvstore.count(), 0)
    self.assertEqual(lsh.count(), 0)

    lsh.hash2uuids_kvstore.add(0, {0})
    self.assertEqual(lsh.descriptor_set.count(), 2)
    self.assertEqual(lsh.count(), 1)

    lsh.hash2uuids_kvstore.add(0, {0, 1})
    self.assertEqual(lsh.descriptor_set.count(), 2)
    self.assertEqual(lsh.count(), 2)

    lsh.hash2uuids_kvstore.add(0, {0, 1, 2})
    self.assertEqual(lsh.descriptor_set.count(), 2)
    self.assertEqual(lsh.count(), 3)
def test_cache_table_empty_table(self):
    inst = MemoryDescriptorSet(DataMemoryElement(), -1)
    inst._table = {}
    expected_table_pickle_bytes = pickle.dumps(inst._table, -1)

    inst.cache_table()
    self.assertIsNotNone(inst.cache_element)
    self.assertEqual(inst.cache_element.get_bytes(),
                     expected_table_pickle_bytes)
def test_from_config_null_cache_elem(self):
    inst = MemoryDescriptorSet.from_config({'cache_element': None})
    self.assertIsNone(inst.cache_element)
    self.assertEqual(inst._table, {})

    inst = MemoryDescriptorSet.from_config({'cache_element': {'type': None}})
    self.assertIsNone(inst.cache_element)
    self.assertEqual(inst._table, {})
def __init__(self, json_config):
    super(SmqtkClassifierService, self).__init__(json_config)

    self.enable_classifier_removal = \
        bool(json_config[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL])

    self.immutable_labels = set(json_config[self.CONFIG_IMMUTABLE_LABELS])

    # Convert configuration into SMQTK plugin instances.
    # - Static classifier configurations.
    #   - Skip the example config key.
    # - Classification element factory
    # - Descriptor generator
    # - Descriptor element factory
    # - from-IQR-state classifier configuration
    #   - There must at least be the default key defined for when no
    #     specific classifier type is specified at state POST.

    # Classifier collection + factory
    self.classification_factory = \
        ClassificationElementFactory.from_config(
            json_config[self.CONFIG_CLASSIFICATION_FACTORY]
        )

    #: :type: ClassifierCollection
    self.classifier_collection = ClassifierCollection.from_config(
        json_config[self.CONFIG_CLASSIFIER_COLLECTION]
    )

    # Descriptor generator + factory
    self.descriptor_factory = DescriptorElementFactory.from_config(
        json_config[self.CONFIG_DESCRIPTOR_FACTORY]
    )

    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_gen = from_config_dict(
        json_config[self.CONFIG_DESCRIPTOR_GENERATOR],
        smqtk.algorithms.DescriptorGenerator.get_impls()
    )

    # Descriptor set bundled for classification-by-UID.
    try:
        self.descriptor_set = from_config_dict(
            json_config.get(self.CONFIG_DESCRIPTOR_SET, {}),
            DescriptorSet.get_impls()
        )
    except ValueError:
        # Default empty set.
        self.descriptor_set = MemoryDescriptorSet()

    # Classifier config for uploaded IQR states.
    self.iqr_state_classifier_config = \
        json_config[self.CONFIG_IQR_CLASSIFIER]

    self.add_routes()
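# For reference, the constructor above reads everything from one JSON
# configuration dictionary keyed by the class's CONFIG_* attributes.  The
# sketch below only illustrates the expected shape; the literal key strings
# and the empty nested plugin blocks are illustrative placeholders, not the
# service's actual default configuration.
example_json_config_sketch = {
    "enable_classifier_removal": False,     # -> CONFIG_ENABLE_CLASSIFIER_REMOVAL
    "immutable_labels": ["example_label"],  # -> CONFIG_IMMUTABLE_LABELS
    "classification_factory": {},           # ClassificationElementFactory config
    "classifier_collection": {},            # ClassifierCollection config
    "descriptor_factory": {},               # DescriptorElementFactory config
    "descriptor_generator": {},             # DescriptorGenerator plugin config
    "descriptor_set": {},                   # optional; falls back to MemoryDescriptorSet
    "iqr_state_classifier_config": {},      # classifier config for uploaded IQR states
}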
def test_remove_from_index_shared_hashes_partial(self):
    """
    Test that only some hashes are removed from the hash index, but not
    others when those hashes still refer to other descriptors.
    """
    # Simulate initial state with some descriptors hashed to one value and
    # other descriptors hashed to another.
    # Vectors of length 1 for easy dummy hashing prediction.
    descriptors = [
        DescriptorMemoryElement('t', 0).set_vector([0]),
        DescriptorMemoryElement('t', 1).set_vector([1]),
        DescriptorMemoryElement('t', 2).set_vector([2]),
        DescriptorMemoryElement('t', 3).set_vector([3]),
        DescriptorMemoryElement('t', 4).set_vector([4]),
    ]

    # Dummy hash function to do the simulated thing
    hash_func = DummyHashFunctor()
    hash_func.get_hash = mock.Mock(
        # Vectors of even sum hash to 0, odd to 1.
        side_effect=lambda vec: [vec.sum() % 2]
    )

    d_set = MemoryDescriptorSet()
    d_set._table = {
        0: descriptors[0],
        1: descriptors[1],
        2: descriptors[2],
        3: descriptors[3],
        4: descriptors[4],
    }

    hash2uid_kvs = MemoryKeyValueStore()
    hash2uid_kvs._table = {
        0: {0, 2, 4},
        1: {1, 3},
    }

    idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uid_kvs)
    idx.hash_index = mock.Mock(spec=HashIndex)

    idx.remove_from_index([1, 2, 3])
    # Check that only one hash vector was passed to hash_index's removal
    # method (deque of hash-code vectors).
    idx.hash_index.remove_from_index.assert_called_once_with(
        collections.deque([
            [1],
        ])
    )
    self.assertDictEqual(d_set._table, {
        0: descriptors[0],
        4: descriptors[4],
    })
    self.assertDictEqual(hash2uid_kvs._table, {0: {0, 4}})
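# The bookkeeping exercised above can be stated on its own: removing UIDs
# from a hash-to-UID mapping should only flag a hash code for removal from
# the hash index when its UID set becomes empty.  This is a standalone
# sketch of that rule, not the LSHNearestNeighborIndex implementation.
def sketch_remove_uids(hash2uids, uids_to_remove):
    """Return the hash keys whose UID sets were emptied by the removal."""
    emptied = []
    for h in list(hash2uids):
        remaining = set(hash2uids[h]) - set(uids_to_remove)
        if remaining:
            hash2uids[h] = remaining
        else:
            del hash2uids[h]
            emptied.append(h)
    return emptied


# Starting from {0: {0, 2, 4}, 1: {1, 3}} and removing UIDs [1, 2, 3], only
# hash 1 empties out, matching the single [1] hash vector expected above.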
def test_added_descriptor_table_caching(self):
    cache_elem = DataMemoryElement(readonly=False)
    descrs = [random_descriptor() for _ in range(3)]
    expected_table = dict((r.uuid(), r) for r in descrs)

    i = MemoryDescriptorSet(cache_elem)
    self.assertTrue(cache_elem.is_empty())

    # Should add descriptors to table, caching to writable element.
    i.add_many_descriptors(descrs)
    self.assertFalse(cache_elem.is_empty())
    self.assertEqual(pickle.loads(i.cache_element.get_bytes()),
                     expected_table)

    # Changes to the internal table (add, remove) should be reflected in
    # the cache.
    new_d = random_descriptor()
    expected_table[new_d.uuid()] = new_d
    i.add_descriptor(new_d)
    self.assertEqual(pickle.loads(i.cache_element.get_bytes()),
                     expected_table)

    rm_d = list(expected_table.values())[0]
    del expected_table[rm_d.uuid()]
    i.remove_descriptor(rm_d.uuid())
    self.assertEqual(pickle.loads(i.cache_element.get_bytes()),
                     expected_table)
def test_get_config(self):
    self.assertEqual(MemoryDescriptorSet().get_config(),
                     MemoryDescriptorSet.get_default_config())

    self.assertEqual(MemoryDescriptorSet(None).get_config(),
                     MemoryDescriptorSet.get_default_config())

    empty_elem = DataMemoryElement()
    dme_key = 'smqtk.representation.data_element.memory_element.DataMemoryElement'
    self.assertEqual(
        MemoryDescriptorSet(empty_elem).get_config(),
        merge_dict(MemoryDescriptorSet.get_default_config(),
                   {'cache_element': {'type': dme_key}})
    )

    dict_pickle_bytes = pickle.dumps({1: 1, 2: 2, 3: 3}, -1)
    dict_pickle_bytes_str = dict_pickle_bytes.decode(BYTES_CONFIG_ENCODING)
    cache_elem = DataMemoryElement(bytes=dict_pickle_bytes)
    self.assertEqual(
        MemoryDescriptorSet(cache_elem).get_config(),
        merge_dict(MemoryDescriptorSet.get_default_config(), {
            'cache_element': {
                dme_key: {
                    'bytes': dict_pickle_bytes_str
                },
                'type': dme_key
            }
        })
    )
def test_update_index_existing_descriptors_frozenset(self):
    """
    Same as ``test_update_index_similar_descriptors`` but testing that we
    can update the index when seeded with structures with existing values.
    """
    # Similar Descriptors to build and update on (different instances)
    descriptors1 = [
        DescriptorMemoryElement('t', 0).set_vector([0]),
        DescriptorMemoryElement('t', 1).set_vector([1]),
        DescriptorMemoryElement('t', 2).set_vector([2]),
        DescriptorMemoryElement('t', 3).set_vector([3]),
        DescriptorMemoryElement('t', 4).set_vector([4]),
    ]
    descriptors2 = [
        DescriptorMemoryElement('t', 5).set_vector([0]),
        DescriptorMemoryElement('t', 6).set_vector([1]),
        DescriptorMemoryElement('t', 7).set_vector([2]),
        DescriptorMemoryElement('t', 8).set_vector([3]),
        DescriptorMemoryElement('t', 9).set_vector([4]),
    ]

    descr_set = MemoryDescriptorSet()
    descr_set.add_many_descriptors(descriptors1)

    hash_kvs = MemoryKeyValueStore()
    hash_kvs.add(0, frozenset({0}))
    hash_kvs.add(1, frozenset({1}))
    hash_kvs.add(2, frozenset({2}))
    hash_kvs.add(3, frozenset({3}))
    hash_kvs.add(4, frozenset({4}))

    index = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set, hash_kvs)
    index.update_index(descriptors2)

    assert descr_set.count() == 10
    # Above descriptors should be considered "in" the descriptor set now.
    for d in descriptors1:
        assert d in descr_set
    for d in descriptors2:
        assert d in descr_set
    # Known hashes of the above descriptors should be in the KVS
    assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
    assert hash_kvs.get(0) == {0, 5}
    assert hash_kvs.get(1) == {1, 6}
    assert hash_kvs.get(2) == {2, 7}
    assert hash_kvs.get(3) == {3, 8}
    assert hash_kvs.get(4) == {4, 9}
def test_configuration(self):
    ex_descr_set = MemoryDescriptorSet()
    ex_i2u_kvs = MemoryKeyValueStore()
    ex_u2i_kvs = MemoryKeyValueStore()
    ex_index_elem = DataMemoryElement()
    ex_index_param_elem = DataMemoryElement()

    i = FaissNearestNeighborsIndex(
        descriptor_set=ex_descr_set,
        idx2uid_kvs=ex_i2u_kvs,
        uid2idx_kvs=ex_u2i_kvs,
        index_element=ex_index_elem,
        index_param_element=ex_index_param_elem,
        read_only=True,
        factory_string=u'some fact str',
        ivf_nprobe=88,
        use_gpu=False,
        gpu_id=99,
        random_seed=8,
    )
    for inst in configuration_test_helper(i):
        assert isinstance(inst._descriptor_set, MemoryDescriptorSet)
        assert isinstance(inst._idx2uid_kvs, MemoryKeyValueStore)
        assert isinstance(inst._uid2idx_kvs, MemoryKeyValueStore)
        assert isinstance(inst._index_element, DataMemoryElement)
        assert isinstance(inst._index_param_element, DataMemoryElement)
        assert inst.read_only is True
        assert isinstance(inst.factory_string, six.string_types)
        assert inst.factory_string == 'some fact str'
        assert inst._ivf_nprobe == 88
        assert inst._use_gpu is False
        assert inst._gpu_id == 99
        assert inst.random_seed == 8
def _make_inst(self, **kwargs):
    """
    Make an instance of MRPTNearestNeighborsIndex.
    """
    if 'random_seed' not in kwargs:
        kwargs.update(random_seed=self.RAND_SEED)
    return MRPTNearestNeighborsIndex(MemoryDescriptorSet(), **kwargs)
def test_nn_pathological_example(self):
    n = 10 ** 4
    dim = 256
    depth = 10
    # L ~ n/2**depth = 10^4 / 2^10 ~ 10
    k = 200
    # 3k/L = 60
    num_trees = 60

    d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
    # Put all descriptors on a line so that different trees get same
    # divisions.
    # noinspection PyTypeChecker
    [d.set_vector(np.full(dim, d.uuid(), dtype=np.float64))
     for d in d_set]
    q = DescriptorMemoryElement('q', -1)
    q.set_vector(np.zeros((dim,)))

    di = MemoryDescriptorSet()
    mrpt = MRPTNearestNeighborsIndex(
        di, num_trees=num_trees, depth=depth, random_seed=0)
    mrpt.build_index(d_set)

    nbrs, dists = mrpt.nn(q, k)
    self.assertEqual(len(nbrs), len(dists))
    # We should get about 10 descriptors back instead of the requested
    # 200.
    self.assertLess(len(nbrs), 20)
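# Quick arithmetic behind the expectation above (not part of the test): each
# depth-10 tree leaf holds roughly n / 2**depth descriptors, and because all
# descriptors lie on the same line every tree carves out the same leaf, so
# the candidate pool does not grow with num_trees.
n, depth = 10 ** 4, 10
leaf_size = n / 2 ** depth  # ~9.77 descriptors per leaf
# Hence the assertion that fewer than 20 neighbors come back even though
# k=200 was requested.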
def test_remove_from_index(self):
    # Test that removing by UIDs does the correct thing.
    # Descriptors are 1 dim, value == index.
    descriptors = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors:
        d.set_vector(np.ones(1, float) * d.uuid())

    d_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    idx = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, hash_kvs)
    idx.build_index(descriptors)

    # Attempt removing 1 uid.
    idx.remove_from_index([3])
    self.assertEqual(idx.descriptor_set._table, {
        0: descriptors[0],
        1: descriptors[1],
        2: descriptors[2],
        4: descriptors[4],
    })
    self.assertEqual(idx.hash2uuids_kvstore._table, {
        0: {0},
        1: {1},
        2: {2},
        4: {4},
    })
def test_configuration(self):
    index_filepath = osp.abspath(osp.expanduser('index_filepath'))
    para_filepath = osp.abspath(osp.expanduser('param_fp'))

    i = MRPTNearestNeighborsIndex(
        descriptor_set=MemoryDescriptorSet(),
        index_filepath=index_filepath,
        parameters_filepath=para_filepath,
        read_only=True,
        num_trees=9,
        depth=2,
        random_seed=8,
        pickle_protocol=0,
        use_multiprocessing=True,
    )
    for inst in configuration_test_helper(i):  # type: MRPTNearestNeighborsIndex
        assert isinstance(inst._descriptor_set, MemoryDescriptorSet)
        assert inst._index_filepath == index_filepath
        assert inst._index_param_filepath == para_filepath
        assert inst._read_only is True
        assert inst._num_trees == 9
        assert inst._depth == 2
        assert inst._rand_seed == 8
        assert inst._pickle_protocol == 0
        assert inst._use_multiprocessing is True
def test_nn_small_leaves(self):
    np.random.seed(0)

    n = 10 ** 4
    dim = 256
    depth = 10
    # L ~ n/2**depth = 10^4 / 2^10 ~ 10
    k = 200
    # 3k/L = 60
    num_trees = 60

    d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
    [d.set_vector(np.random.rand(dim)) for d in d_set]
    q = DescriptorMemoryElement('q', -1)
    q.set_vector(np.zeros((dim,)))

    di = MemoryDescriptorSet()
    mrpt = MRPTNearestNeighborsIndex(
        di, num_trees=num_trees, depth=depth, random_seed=0)
    mrpt.build_index(d_set)

    nbrs, dists = mrpt.nn(q, k)
    self.assertEqual(len(nbrs), len(dists))
    self.assertEqual(len(nbrs), k)
def test_update_index_with_hash_index(self):
    # Similar test to `test_update_index_add_new_descriptors` but with a
    # linear hash index.
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    linear_hi = LinearHashIndex()  # simplest hash index, heap-sorts.
    index = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set,
                                    hash_kvs, linear_hi)

    descriptors1 = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    descriptors2 = [
        DescriptorMemoryElement('t', 5),
        DescriptorMemoryElement('t', 6),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors1 + descriptors2:
        d.set_vector(np.ones(1, float) * d.uuid())

    # Build initial index.
    index.build_index(descriptors1)
    # Initial hash index should only encode hashes for first batch of
    # descriptors.
    self.assertSetEqual(linear_hi.index, {0, 1, 2, 3, 4})

    # Update index and check that components have new data.
    index.update_index(descriptors2)
    # Now the hash index should include all descriptor hashes.
    self.assertSetEqual(linear_hi.index, {0, 1, 2, 3, 4, 5, 6})
def _random_euclidean(self, hash_ftor, hash_idx,
                      ftor_train_hook=lambda d: None):
    # :param hash_ftor: Hash function class for generating hash codes for
    #     descriptors.
    # :param hash_idx: Hash index instance to use in local LSH algo
    #     instance.
    # :param ftor_train_hook: Function for training functor if necessary.

    # make random descriptors
    i = 1000
    dim = 256
    td = []
    np.random.seed(self.RANDOM_SEED)
    for j in range(i):
        d = DescriptorMemoryElement('random', j)
        d.set_vector(np.random.rand(dim))
        td.append(d)

    ftor_train_hook(td)

    di = MemoryDescriptorSet()
    kvstore = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(hash_ftor, di, kvstore,
                                    hash_index=hash_idx,
                                    distance_method='euclidean')
    index.build_index(td)

    # test query from build set -- should return same descriptor when k=1
    q = td[255]
    r, dists = index.nn(q, 1)
    self.assertEqual(r[0], q)

    # test query very near a build vector
    td_q = td[0]
    q = DescriptorMemoryElement('query', i)
    v = td_q.vector().copy()
    v_min = max(v.min(), 0.1)
    v[0] += v_min
    v[dim - 1] -= v_min
    q.set_vector(v)
    r, dists = index.nn(q, 1)
    self.assertFalse(np.array_equal(q.vector(), td_q.vector()))
    self.assertEqual(r[0], td_q)

    # random query
    q = DescriptorMemoryElement('query', i + 1)
    q.set_vector(np.random.rand(dim))

    # for any query of size k, results should at least be in distance order
    r, dists = index.nn(q, 10)
    for j in range(1, len(dists)):
        self.assertGreater(dists[j], dists[j - 1])
    r, dists = index.nn(q, i)
    for j in range(1, len(dists)):
        self.assertGreater(dists[j], dists[j - 1])
def test_clustering_equal_descriptors(self):
    # Test that clusters of descriptors of size n-features are correctly
    # clustered together.
    print("Creating dummy descriptors")
    n_features = 8
    n_descriptors = 20

    descr_set = MemoryDescriptorSet()
    c = 0
    for i in range(n_features):
        v = numpy.ndarray((8,))
        v[...] = 0
        v[i] = 1
        for j in range(n_descriptors):
            d = DescriptorMemoryElement('test', c)
            d.set_vector(v)
            descr_set.add_descriptor(d)
            c += 1

    print("Creating test MBKM")
    mbkm = MiniBatchKMeans(n_features, batch_size=12, verbose=True,
                           compute_labels=False, random_state=0)

    # Initial fit with half of descr_set
    d_classes = mb_kmeans_build_apply(descr_set, mbkm, n_descriptors)

    # There should be 20 descriptors per class
    for c in d_classes:
        self.assertEqual(
            len(d_classes[c]), n_descriptors,
            "Cluster %s did not have expected number of descriptors "
            "(%d != %d)" % (c, n_descriptors, len(d_classes[c])))

        # Each descriptor in each cluster should be equal to the other
        # descriptors in that cluster
        uuids = list(d_classes[c])
        v = descr_set[uuids[0]].vector()
        for uuid in uuids[1:]:
            v2 = descr_set[uuid].vector()
            numpy.testing.assert_array_equal(
                v, v2,
                "vector in cluster %d did not match other vectors "
                "(%s != %s)" % (c, v, v2))
def test_remove_from_index_read_only(self):
    d_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    idx = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, hash_kvs,
                                  read_only=True)
    self.assertRaises(
        ReadOnlyError,
        idx.remove_from_index, ['uid1', 'uid2']
    )
def test_remove_from_index_no_existing_index(self):
    # Test that attempting to remove from an instance with no existing
    # index (meaning empty descriptor-set and key-value-store) results in
    # a key error.
    d_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    idx = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, hash_kvs)
    self.assertRaisesRegex(
        KeyError, 'uid1',
        idx.remove_from_index, ['uid1']
    )
def test_update_index_add_new_descriptors(self):
    # Test that calling update index after a build index causes index
    # components to be properly updated.
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(DummyHashFunctor(),
                                    descr_set, hash_kvs)
    descriptors1 = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    descriptors2 = [
        DescriptorMemoryElement('t', 5),
        DescriptorMemoryElement('t', 6),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors1 + descriptors2:
        d.set_vector(np.ones(1, float) * d.uuid())

    # Build initial index.
    index.build_index(descriptors1)
    self.assertEqual(descr_set.count(), 5)
    for d in descriptors1:
        self.assertIn(d, descr_set)
    for d in descriptors2:
        self.assertNotIn(d, descr_set)
    # Dummy hash function bins sum of descriptor vectors.
    self.assertEqual(hash_kvs.count(), 5)
    for i in range(5):
        self.assertSetEqual(hash_kvs.get(i), {i})

    # Update index and check that components have new data.
    index.update_index(descriptors2)
    self.assertEqual(descr_set.count(), 7)
    for d in descriptors1 + descriptors2:
        self.assertIn(d, descr_set)
    # Dummy hash function bins sum of descriptor vectors.
    self.assertEqual(hash_kvs.count(), 7)
    for i in range(7):
        self.assertSetEqual(hash_kvs.get(i), {i})
def test_update_index_similar_descriptors(self):
    """
    Test that updating a built index with similar descriptors (same
    vectors, different UUIDs) results in contained structures having an
    expected state.
    """
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(DummyHashFunctor(),
                                    descr_set, hash_kvs)

    # Similar Descriptors to build and update on (different instances)
    descriptors1 = [
        DescriptorMemoryElement('t', 0).set_vector([0]),
        DescriptorMemoryElement('t', 1).set_vector([1]),
        DescriptorMemoryElement('t', 2).set_vector([2]),
        DescriptorMemoryElement('t', 3).set_vector([3]),
        DescriptorMemoryElement('t', 4).set_vector([4]),
    ]
    descriptors2 = [
        DescriptorMemoryElement('t', 5).set_vector([0]),
        DescriptorMemoryElement('t', 6).set_vector([1]),
        DescriptorMemoryElement('t', 7).set_vector([2]),
        DescriptorMemoryElement('t', 8).set_vector([3]),
        DescriptorMemoryElement('t', 9).set_vector([4]),
    ]

    index.build_index(descriptors1)
    index.update_index(descriptors2)

    assert descr_set.count() == 10
    # Above descriptors should be considered "in" the descriptor set now.
    for d in descriptors1:
        assert d in descr_set
    for d in descriptors2:
        assert d in descr_set
    # Known hashes of the above descriptors should be in the KVS
    assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
    assert hash_kvs.get(0) == {0, 5}
    assert hash_kvs.get(1) == {1, 6}
    assert hash_kvs.get(2) == {2, 7}
    assert hash_kvs.get(3) == {3, 8}
    assert hash_kvs.get(4) == {4, 9}
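# The KVS expectations above follow from the update rule the LSH index is
# assumed to apply: a new descriptor's UID is unioned into whatever UID set
# is already stored under its hash key.  A standalone sketch of that rule,
# not the index implementation:
def sketch_update_bins(hash2uids, new_uid_by_hash):
    """Union newly hashed UIDs into their existing hash bins."""
    for h, uid in new_uid_by_hash:
        hash2uids[h] = set(hash2uids.get(h, set())) | {uid}
    return hash2uids


# e.g. bins {0: {0}, ..., 4: {4}} updated with UIDs 5..9 hashed to 0..4
# respectively become {0: {0, 5}, 1: {1, 6}, ..., 4: {4, 9}}.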
def test_update_index_duplicate_descriptors(self):
    """
    Test that updating a built index with the same descriptors results in
    idempotent behavior.
    """
    descr_set = MemoryDescriptorSet()
    hash_kvs = MemoryKeyValueStore()
    index = LSHNearestNeighborIndex(DummyHashFunctor(),
                                    descr_set, hash_kvs)

    # Identical Descriptors to build and update on (different instances)
    descriptors1 = [
        DescriptorMemoryElement('t', 0).set_vector([0]),
        DescriptorMemoryElement('t', 1).set_vector([1]),
        DescriptorMemoryElement('t', 2).set_vector([2]),
        DescriptorMemoryElement('t', 3).set_vector([3]),
        DescriptorMemoryElement('t', 4).set_vector([4]),
    ]
    descriptors2 = [
        DescriptorMemoryElement('t', 0).set_vector([0]),
        DescriptorMemoryElement('t', 1).set_vector([1]),
        DescriptorMemoryElement('t', 2).set_vector([2]),
        DescriptorMemoryElement('t', 3).set_vector([3]),
        DescriptorMemoryElement('t', 4).set_vector([4]),
    ]

    index.build_index(descriptors1)
    index.update_index(descriptors2)

    assert descr_set.count() == 5
    # Above descriptors should be considered "in" the descriptor set now.
    for d in descriptors1:
        assert d in descr_set
    for d in descriptors2:
        assert d in descr_set
    # Known hashes of the above descriptors should be in the KVS
    assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
    assert hash_kvs.get(0) == {0}
    assert hash_kvs.get(1) == {1}
    assert hash_kvs.get(2) == {2}
    assert hash_kvs.get(3) == {3}
    assert hash_kvs.get(4) == {4}
def test_init_with_cache(self):
    d_list = (random_descriptor(), random_descriptor(),
              random_descriptor(), random_descriptor())
    expected_table = dict((r.uuid(), r) for r in d_list)
    expected_cache = DataMemoryElement(bytes=pickle.dumps(expected_table))

    inst = MemoryDescriptorSet(expected_cache)
    self.assertEqual(len(inst._table), 4)
    self.assertEqual(inst.cache_element, expected_cache)
    self.assertEqual(inst._table, expected_table)
    self.assertEqual(set(inst._table.values()), set(d_list))
def test_add_many(self):
    descrs = [
        random_descriptor(),
        random_descriptor(),
        random_descriptor(),
        random_descriptor(),
        random_descriptor(),
    ]
    index = MemoryDescriptorSet()
    index.add_many_descriptors(descrs)

    # Compare code keys of input to code keys in internal table
    self.assertEqual(set(index._table.keys()),
                     set(e.uuid() for e in descrs))

    # Get the set of descriptors in the internal table and compare it with
    # the set of generated random descriptors.
    r_set = set(index._table.values())
    self.assertEqual(set(descrs), r_set)
def test_remove_from_index_shared_hashes(self):
    """
    Test that removing a descriptor (by UID) that shares a hash with other
    descriptors does not trigger removal of its hash.
    """
    # Simulate descriptors all hashing to the same hash value: 0
    hash_func = DummyHashFunctor()
    hash_func.get_hash = mock.Mock(return_value=np.asarray([0], bool))

    d_set = MemoryDescriptorSet()
    hash2uids_kvs = MemoryKeyValueStore()
    idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uids_kvs)

    # Descriptors are 1 dim, value == index.
    descriptors = [
        DescriptorMemoryElement('t', 0),
        DescriptorMemoryElement('t', 1),
        DescriptorMemoryElement('t', 2),
        DescriptorMemoryElement('t', 3),
        DescriptorMemoryElement('t', 4),
    ]
    # Vectors of length 1 for easy dummy hashing prediction.
    for d in descriptors:
        d.set_vector(np.ones(1, float) * d.uuid())
    idx.build_index(descriptors)

    # We expect the descriptor-set and kvs to look like the following now:
    self.assertDictEqual(d_set._table, {
        0: descriptors[0],
        1: descriptors[1],
        2: descriptors[2],
        3: descriptors[3],
        4: descriptors[4],
    })
    self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 2, 3, 4}})

    # Mock out hash index as if we had an implementation so we can check
    # call to its remove_from_index method.
    idx.hash_index = mock.Mock(spec=HashIndex)

    idx.remove_from_index([2, 4])
    # Only uid 2 and 4 descriptors should be gone from d-set, kvs should
    # still have the 0 key and its set value should only contain uids 0, 1
    # and 3.  `hash_index.remove_from_index` should not be called because
    # no hashes should be marked for removal.
    self.assertDictEqual(d_set._table, {
        0: descriptors[0],
        1: descriptors[1],
        3: descriptors[3],
    })
    self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 3}})
    idx.hash_index.remove_from_index.assert_not_called()
def test_has(self):
    i = MemoryDescriptorSet()
    descrs = [random_descriptor() for _ in range(10)]
    i.add_many_descriptors(descrs)

    self.assertTrue(i.has_descriptor(descrs[4].uuid()))
    self.assertFalse(i.has_descriptor('not_an_int'))
def test_from_config_null_cache_elem_type(self):
    # An empty cache should not trigger loading on construction.
    expected_empty_cache = DataMemoryElement()
    dme_key = 'smqtk.representation.data_element.memory_element.DataMemoryElement'
    inst = MemoryDescriptorSet.from_config({
        'cache_element': {
            'type': dme_key,
            dme_key: {'bytes': ''}
        }
    })
    self.assertEqual(inst.cache_element, expected_empty_cache)
    self.assertEqual(inst._table, {})
def test_from_config_null_cache_elem_type(self):
    # An empty cache should not trigger loading on construction.
    expected_empty_cache = DataMemoryElement()
    inst = MemoryDescriptorSet.from_config({
        'cache_element': {
            'type': 'DataMemoryElement',
            'DataMemoryElement': {'bytes': ''}
        }
    })
    self.assertEqual(inst.cache_element, expected_empty_cache)
    self.assertEqual(inst._table, {})