Esempio n. 1
0
 def test_natural_iter(self) -> None:
     """
     Iterating over a descriptor set should yield exactly the descriptor
     elements that were added to it.
     """
     d_set = MemoryDescriptorSet()
     added = [random_descriptor() for _ in range(100)]
     d_set.add_many_descriptors(added)
     # Iteration order is not checked: contents are compared as sets.
     self.assertEqual(set(d_set), set(added))
Esempio n. 2
0
    def test_build_index_fresh_build(self) -> None:
        """
        Building on a fresh instance should fill both the attached
        descriptor set and the hash key-value store.
        """
        d_set = MemoryDescriptorSet()
        kvs = MemoryKeyValueStore()
        lsh = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, kvs)

        elements = [DescriptorMemoryElement('t', uid) for uid in range(5)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for position, element in enumerate(elements):
            element.set_vector(np.ones(1, float) * position)
        lsh.build_index(elements)

        # Every descriptor must now be present in the attached set.
        self.assertEqual(d_set.count(), 5)
        for element in elements:
            self.assertIn(element, d_set)
        # Dummy hash function bins sum of descriptor vectors, so each vector
        # value is expected as a key mapping to the matching UID.
        self.assertEqual(kvs.count(), 5)
        for value in range(5):
            self.assertSetEqual(kvs.get(value), {value})
Esempio n. 3
0
    def test_update_index_no_existing_index(self) -> None:
        """
        Calling ``update_index`` with no existing index should act like a
        fresh build.  Mirrors ``test_build_index_fresh_build`` but goes
        through ``update_index`` instead.
        """
        d_set = MemoryDescriptorSet()
        kvs = MemoryKeyValueStore()
        lsh = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, kvs)

        elements = [DescriptorMemoryElement('t', uid) for uid in range(5)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in elements:
            element.set_vector(np.ones(1, float) * element.uuid())
        lsh.update_index(elements)

        # Every descriptor must now be present in the attached set.
        self.assertEqual(d_set.count(), 5)
        for element in elements:
            self.assertIn(element, d_set)
        # Dummy hash function bins sum of descriptor vectors.
        self.assertEqual(kvs.count(), 5)
        for value in range(5):
            self.assertSetEqual(kvs.get(value), {value})
Esempio n. 4
0
    def test_count_empty_hash2uid(self) -> None:
        """
        The LSH index count should follow the hash-to-uid mapping only: it is
        0 while that mapping is empty, regardless of descriptor-set state,
        and afterwards equals the number of UIDs stored in the mapping.
        """
        d_set = MemoryDescriptorSet()
        kvs = MemoryKeyValueStore()
        self.assertEqual(d_set.count(), 0)
        self.assertEqual(kvs.count(), 0)

        lsh = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, kvs)
        self.assertEqual(lsh.count(), 0)

        # Additions to the descriptor-set should not impact LSH index "size"
        for n_descriptors, uid in enumerate((0, 1), start=1):
            lsh.descriptor_set.add_descriptor(DescriptorMemoryElement('t', uid))
            self.assertEqual(lsh.descriptor_set.count(), n_descriptors)
            self.assertEqual(lsh.hash2uuids_kvstore.count(), 0)
            self.assertEqual(lsh.count(), 0)

        # Growing the UID set stored under hash 0 should grow the LSH count
        # while the descriptor-set count stays put.
        for uids in ({0}, {0, 1}, {0, 1, 2}):
            lsh.hash2uuids_kvstore.add(0, uids)
            self.assertEqual(lsh.descriptor_set.count(), 2)
            self.assertEqual(lsh.count(), len(uids))
Esempio n. 5
0
    def test_cache_table_empty_table(self) -> None:
        """
        Caching an empty table should store the pickled empty dict in the
        cache element.
        """
        d_set = MemoryDescriptorSet(DataMemoryElement(), -1)
        d_set._table = {}
        expected_bytes = pickle.dumps(d_set._table, -1)

        d_set.cache_table()

        cache = d_set.cache_element
        assert cache is not None
        self.assertEqual(cache.get_bytes(), expected_bytes)
Esempio n. 6
0
    def test_remove_from_index_shared_hashes_partial(self) -> None:
        """
        Test that a hash is removed from the hash index only when no
        remaining descriptor refers to it, while hashes still referring to
        other descriptors are kept.
        """
        # Simulated initial state: some descriptors hash to one value and
        # the rest hash to another.

        # Vectors of length 1 for easy dummy hashing prediction.
        descriptors = [
            DescriptorMemoryElement(uid).set_vector([uid])
            for uid in range(5)
        ]

        # Dummy hash function, patched so that vectors of even sum hash to
        # 0 and vectors of odd sum hash to 1.
        hash_func = DummyHashFunctor()
        hash_func.get_hash = mock.Mock(  # type: ignore
            side_effect=lambda vec: [vec.sum() % 2]
        )

        # Seed the descriptor set and the hash -> UIDs store directly.
        d_set = MemoryDescriptorSet()
        d_set._table = {uid: descriptors[uid] for uid in range(5)}

        hash2uid_kvs = MemoryKeyValueStore()
        hash2uid_kvs._table = {
            0: {0, 2, 4},
            1: {1, 3},
        }

        idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uid_kvs)
        idx.hash_index = mock.Mock(spec=HashIndex)

        idx.remove_from_index([1, 2, 3])

        # Check that only one hash vector was passed to hash_index's removal
        # method (deque of hash-code vectors).
        expected_removed_hashes = collections.deque([[1]])
        idx.hash_index.remove_from_index.assert_called_once_with(
            expected_removed_hashes
        )
        # UIDs 0 and 4 remain, both stored under hash 0.
        self.assertDictEqual(
            d_set._table,
            {uid: descriptors[uid] for uid in (0, 4)},
        )
        self.assertDictEqual(hash2uid_kvs._table, {0: {0, 4}})
0
    def test_from_config_null_cache_elem(self) -> None:
        """
        A configuration whose cache element is ``None``, or whose cache
        element type is ``None``, should produce an instance with no cache
        element and an empty table.
        """
        null_cache_configs = (
            {'cache_element': None},
            {'cache_element': {'type': None}},
        )
        for config in null_cache_configs:
            inst = MemoryDescriptorSet.from_config(config)
            self.assertIsNone(inst.cache_element)
            self.assertEqual(inst._table, {})
Esempio n. 8
0
    def test_added_descriptor_table_caching(self) -> None:
        """
        Adding and removing descriptors should be reflected in the bytes
        held by a writable cache element.
        """
        cache_elem = DataMemoryElement(readonly=False)
        descrs = [random_descriptor() for _ in range(3)]
        expected_table = {d.uuid(): d for d in descrs}

        d_set = MemoryDescriptorSet(cache_elem)
        assert d_set.cache_element is not None
        self.assertTrue(cache_elem.is_empty())

        # Should add descriptors to table, caching to writable element.
        d_set.add_many_descriptors(descrs)
        self.assertFalse(cache_elem.is_empty())
        cached_table = pickle.loads(d_set.cache_element.get_bytes())
        self.assertEqual(cached_table, expected_table)

        # Adding one more descriptor should show up in the cache.
        extra = random_descriptor()
        expected_table[extra.uuid()] = extra
        d_set.add_descriptor(extra)
        cached_table = pickle.loads(d_set.cache_element.get_bytes())
        self.assertEqual(cached_table, expected_table)

        # Removing the first expected descriptor should show up as well.
        removed = next(iter(expected_table.values()))
        del expected_table[removed.uuid()]
        d_set.remove_descriptor(removed.uuid())
        cached_table = pickle.loads(d_set.cache_element.get_bytes())
        self.assertEqual(cached_table, expected_table)
Esempio n. 9
0
    def test_get_config(self) -> None:
        """
        ``get_config`` should equal the default configuration when there is
        no cache element, and otherwise include the cache element's type and
        content.
        """
        get_default = MemoryDescriptorSet.get_default_config

        # No cache element: omitted, then passed explicitly as None.
        for ctor_args in ((), (None,)):
            self.assertEqual(
                MemoryDescriptorSet(*ctor_args).get_config(),
                get_default())

        # Empty cache element: only the element type is expected on top of
        # the defaults.
        dme_key = 'smqtk_dataprovider.impls.data_element.memory.DataMemoryElement'
        actual = MemoryDescriptorSet(DataMemoryElement()).get_config()
        expected = merge_dict(get_default(),
                              {'cache_element': {'type': dme_key}})
        self.assertEqual(actual, expected)

        # Cache element holding a pickled dict: its bytes are expected in
        # the config, decoded with the configuration byte encoding.
        table_bytes = pickle.dumps({1: 1, 2: 2, 3: 3}, -1)
        table_bytes_str = table_bytes.decode(BYTES_CONFIG_ENCODING)
        cache_elem = DataMemoryElement(bytes=table_bytes)
        actual = MemoryDescriptorSet(cache_elem).get_config()
        expected = merge_dict(get_default(), {
            'cache_element': {
                dme_key: {'bytes': table_bytes_str},
                'type': dme_key,
            },
        })
        self.assertEqual(actual, expected)
Esempio n. 10
0
    def test_get_descriptors(self) -> None:
        """
        Fetching a single descriptor by UID should return the element that
        was added under that UID.
        """
        descrs = [random_descriptor() for _ in range(5)]
        d_set = MemoryDescriptorSet()
        d_set.add_many_descriptors(descrs)

        # single descriptor reference
        target = descrs[1]
        fetched = d_set.get_descriptor(target.uuid())
        self.assertEqual(fetched, target)
Esempio n. 11
0
    def test_update_index_existing_descriptors_frozenset(self) -> None:
        """
        Same as ``test_update_index_similar_descriptors`` but testing that
        we can update the index when it is seeded with structures that
        already hold values (here, ``frozenset`` UID sets in the key-value
        store).
        """
        # Similar descriptors (same vectors, different UIDs and instances):
        # the first batch seeds the index, the second updates it.
        descriptors1 = [
            DescriptorMemoryElement('t', v).set_vector([v])
            for v in range(5)
        ]
        descriptors2 = [
            DescriptorMemoryElement('t', v + 5).set_vector([v])
            for v in range(5)
        ]

        descr_set = MemoryDescriptorSet()
        descr_set.add_many_descriptors(descriptors1)

        hash_kvs = MemoryKeyValueStore()
        for v in range(5):
            hash_kvs.add(v, frozenset({v}))

        index = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set,
                                        hash_kvs)
        index.update_index(descriptors2)

        assert descr_set.count() == 10
        # Above descriptors should be considered "in" the descriptor set now.
        for d in descriptors1 + descriptors2:
            assert d in descr_set
        # Known hashes of the above descriptors should be in the KVS, each
        # now mapping to the seeded UID plus the matching new UID.
        assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
        for v in range(5):
            assert hash_kvs.get(v) == {v, v + 5}
Esempio n. 12
0
    def test_update_index_with_hash_index(self) -> None:
        """
        Similar to ``test_update_index_add_new_descriptors`` but with a
        linear hash index attached, checking the hash index content after
        the build and after the update.
        """
        d_set = MemoryDescriptorSet()
        kvs = MemoryKeyValueStore()
        linear_hi = LinearHashIndex()  # simplest hash index, heap-sorts.
        lsh = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, kvs,
                                      linear_hi)

        first_batch = [DescriptorMemoryElement('t', uid) for uid in range(5)]
        second_batch = [DescriptorMemoryElement('t', uid) for uid in (5, 6)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in first_batch + second_batch:
            element.set_vector(np.ones(1, float) * element.uuid())

        # After the initial build the hash index should only encode hashes
        # for the first batch of descriptors.
        lsh.build_index(first_batch)
        self.assertSetEqual(linear_hi.index, {0, 1, 2, 3, 4})

        # After the update it should include all descriptor hashes.
        lsh.update_index(second_batch)
        self.assertSetEqual(linear_hi.index, {0, 1, 2, 3, 4, 5, 6})
Esempio n. 13
0
    def test_remove_from_index(self) -> None:
        """
        Removing a single UID should drop it from both the descriptor set
        and the hash-to-UIDs key-value store.
        """
        # Descriptors are 1 dim, value == index.
        descriptors = [DescriptorMemoryElement(uid) for uid in range(5)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in descriptors:
            element.set_vector(np.ones(1, float) * element.uuid())
        d_set = MemoryDescriptorSet()
        hash_kvs = MemoryKeyValueStore()
        idx = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, hash_kvs)
        idx.build_index(descriptors)

        # Attempt removing 1 uid.
        idx.remove_from_index([3])

        remaining_uids = (0, 1, 2, 4)
        assert isinstance(idx.descriptor_set, MemoryDescriptorSet)
        self.assertEqual(
            idx.descriptor_set._table,
            {uid: descriptors[uid] for uid in remaining_uids},
        )
        assert isinstance(idx.hash2uuids_kvstore, MemoryKeyValueStore)
        self.assertEqual(
            idx.hash2uuids_kvstore._table,
            {uid: {uid} for uid in remaining_uids},
        )
Esempio n. 14
0
    def test_configuration(self) -> None:
        """
        Every instance yielded by ``configuration_test_helper`` should
        carry the constructor values given to the original MRPT index.
        """
        index_filepath = osp.abspath(osp.expanduser('index_filepath'))
        para_filepath = osp.abspath(osp.expanduser('param_fp'))

        i = MRPTNearestNeighborsIndex(
            descriptor_set=MemoryDescriptorSet(),
            index_filepath=index_filepath,
            parameters_filepath=para_filepath,
            read_only=True,
            num_trees=9,
            depth=2,
            random_seed=8,
            pickle_protocol=0,
            use_multiprocessing=True,
        )
        helper_insts = configuration_test_helper(i)
        for inst in helper_insts:  # type: MRPTNearestNeighborsIndex
            assert isinstance(inst._descriptor_set, MemoryDescriptorSet)
            # File paths.
            assert inst._index_filepath == index_filepath
            assert inst._index_param_filepath == para_filepath
            # Scalar options.
            assert inst._read_only is True
            assert inst._num_trees == 9
            assert inst._depth == 2
            assert inst._rand_seed == 8
            assert inst._pickle_protocol == 0
            assert inst._use_multiprocessing is True
Esempio n. 15
0
 def _make_inst(self, **kwargs: Any) -> MRPTNearestNeighborsIndex:
     """
     Build an ``MRPTNearestNeighborsIndex`` backed by a fresh
     ``MemoryDescriptorSet``.

     :param kwargs: Keyword arguments forwarded to the index constructor.
         ``random_seed`` falls back to ``self.RAND_SEED`` when not given.
     :return: The new index instance.
     """
     seed_given = 'random_seed' in kwargs
     if not seed_given:
         kwargs['random_seed'] = self.RAND_SEED
     descriptor_set = MemoryDescriptorSet()
     return MRPTNearestNeighborsIndex(descriptor_set, **kwargs)
Esempio n. 16
0
    def test_nn_pathological_example(self) -> None:
        """
        Query an index built from descriptors that all lie on one line and
        check that far fewer neighbors than requested are returned.
        """
        n = 10**4
        dim = 256
        depth = 10
        # L ~ n/2**depth = 10^4 / 2^10 ~ 10
        k = 200
        # 3k/L = 60
        num_trees = 60

        d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
        # Put all descriptors on a line so that different trees get same
        # divisions.  A plain loop is used because ``set_vector`` is called
        # only for its side effect.
        for d in d_set:
            d.set_vector(np.full(dim, d.uuid(), dtype=np.float64))
        q = DescriptorMemoryElement('q', -1)
        q.set_vector(np.zeros((dim, )))

        di = MemoryDescriptorSet()
        mrpt = MRPTNearestNeighborsIndex(di,
                                         num_trees=num_trees,
                                         depth=depth,
                                         random_seed=0)
        mrpt.build_index(d_set)

        nbrs, dists = mrpt.nn(q, k)
        self.assertEqual(len(nbrs), len(dists))
        # We should get about 10 descriptors back instead of the requested
        # 200
        self.assertLess(len(nbrs), 20)
Esempio n. 17
0
    def test_configuration(self) -> None:
        """
        Every instance yielded by ``configuration_test_helper`` should
        carry the constructor values given to the original FAISS index.
        """
        i = FaissNearestNeighborsIndex(
            descriptor_set=MemoryDescriptorSet(),
            idx2uid_kvs=MemoryKeyValueStore(),
            uid2idx_kvs=MemoryKeyValueStore(),
            index_element=DataMemoryElement(),
            index_param_element=DataMemoryElement(),
            read_only=True,
            factory_string=u'some fact str',
            ivf_nprobe=88,
            use_gpu=False,
            gpu_id=99,
            random_seed=8,
        )
        for inst in configuration_test_helper(i):
            # Nested components should come back as the same types.
            assert isinstance(inst._descriptor_set, MemoryDescriptorSet)
            assert isinstance(inst._idx2uid_kvs, MemoryKeyValueStore)
            assert isinstance(inst._uid2idx_kvs, MemoryKeyValueStore)
            assert isinstance(inst._index_element, DataMemoryElement)
            assert isinstance(inst._index_param_element, DataMemoryElement)
            # Scalar options.
            assert inst.read_only is True
            assert isinstance(inst.factory_string, str)
            assert inst.factory_string == 'some fact str'
            assert inst._ivf_nprobe == 88
            assert inst._use_gpu is False
            assert inst._gpu_id == 99
            assert inst.random_seed == 8
Esempio n. 18
0
    def test_nn_small_leaves(self) -> None:
        """
        Query an index built from random descriptors and check that exactly
        the requested number of neighbors is returned.
        """
        # Fixed seed so the random vectors are reproducible.
        np.random.seed(0)

        n = 10**4
        dim = 256
        depth = 10
        # L ~ n/2**depth = 10^4 / 2^10 ~ 10
        k = 200
        # 3k/L = 60
        num_trees = 60

        d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
        # A plain loop is used because ``set_vector`` is called only for its
        # side effect; vectors are drawn in the same order as the elements.
        for d in d_set:
            d.set_vector(np.random.rand(dim))
        q = DescriptorMemoryElement('q', -1)
        q.set_vector(np.zeros((dim, )))

        di = MemoryDescriptorSet()
        mrpt = MRPTNearestNeighborsIndex(di,
                                         num_trees=num_trees,
                                         depth=depth,
                                         random_seed=0)
        mrpt.build_index(d_set)

        nbrs, dists = mrpt.nn(q, k)
        self.assertEqual(len(nbrs), len(dists))
        self.assertEqual(len(nbrs), k)
Esempio n. 19
0
    def test_get_many_descriptor(self) -> None:
        """
        Fetching several descriptors by UID should return exactly the
        elements that were added under those UIDs.
        """
        descrs = [random_descriptor() for _ in range(5)]
        d_set = MemoryDescriptorSet()
        d_set.add_many_descriptors(descrs)

        # multiple descriptor reference
        wanted_uids = [descrs[0].uuid(), descrs[3].uuid()]
        fetched = list(d_set.get_many_descriptors(wanted_uids))
        self.assertEqual(len(fetched), 2)
        self.assertEqual(set(fetched), {descrs[0], descrs[3]})
Esempio n. 20
0
    def test_remove_from_index_invalid_uid(self) -> None:
        """
        Attempting to remove an invalid UID, alone or together with a valid
        one, should raise a ``KeyError`` and leave the index unmodified.
        """
        # Descriptors are 1 dim, value == index.
        descriptors = [DescriptorMemoryElement(uid) for uid in range(5)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in descriptors:
            element.set_vector(np.ones(1, float) * element.uuid())
        # uid -> descriptor
        expected_dset_table = {uid: descriptors[uid] for uid in range(5)}
        # hash int -> set[uid]
        expected_kvs_table = {uid: {uid} for uid in range(5)}

        d_set = MemoryDescriptorSet()
        hash_kvs = MemoryKeyValueStore()
        idx = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, hash_kvs)
        idx.build_index(descriptors)
        # Assert we have the correct expected values
        assert isinstance(idx.descriptor_set, MemoryDescriptorSet)
        self.assertEqual(idx.descriptor_set._table, expected_dset_table)
        assert isinstance(idx.hash2uuids_kvstore, MemoryKeyValueStore)
        self.assertEqual(idx.hash2uuids_kvstore._table, expected_kvs_table)

        # First a lone UID we did not build with, then a valid UID paired
        # with that invalid one.
        for bad_uids in ([5], [2, 5]):
            with self.assertRaisesRegex(KeyError, '5'):
                idx.remove_from_index(bad_uids)
            # Index should not have been modified.
            self.assertEqual(idx.descriptor_set._table, expected_dset_table)
            self.assertEqual(idx.hash2uuids_kvstore._table,
                             expected_kvs_table)
Esempio n. 21
0
 def test_update_index_read_only(self) -> None:
     """
     Updating an index constructed with ``read_only=True`` should raise
     a ``ReadOnlyError``.
     """
     read_only_index = LSHNearestNeighborIndex(
         DummyHashFunctor(),
         MemoryDescriptorSet(),
         MemoryKeyValueStore(),
         read_only=True,
     )
     with self.assertRaises(ReadOnlyError):
         read_only_index._update_index([])
Esempio n. 22
0
 def test_remove_from_index_read_only(self) -> None:
     """
     Removing from an index constructed with ``read_only=True`` should
     raise a ``ReadOnlyError``.
     """
     read_only_index = LSHNearestNeighborIndex(
         DummyHashFunctor(),
         MemoryDescriptorSet(),
         MemoryKeyValueStore(),
         read_only=True,
     )
     with self.assertRaises(ReadOnlyError):
         read_only_index.remove_from_index(['uid1', 'uid2'])
Esempio n. 23
0
 def test_remove_from_index_no_existing_index(self) -> None:
     """
     Removing from an instance with no existing index (meaning an empty
     descriptor-set and key-value-store) should raise a ``KeyError``.
     """
     empty_index = LSHNearestNeighborIndex(
         DummyHashFunctor(),
         MemoryDescriptorSet(),
         MemoryKeyValueStore(),
     )
     with self.assertRaisesRegex(KeyError, 'uid1'):
         empty_index.remove_from_index(['uid1'])
Esempio n. 24
0
    def test_update_index_add_new_descriptors(self) -> None:
        """
        Calling ``update_index`` after ``build_index`` should extend the
        descriptor set and the hash key-value store with the new
        descriptors.
        """
        d_set = MemoryDescriptorSet()
        kvs = MemoryKeyValueStore()
        lsh = LSHNearestNeighborIndex(DummyHashFunctor(), d_set, kvs)

        first_batch = [DescriptorMemoryElement('t', uid) for uid in range(5)]
        second_batch = [DescriptorMemoryElement('t', uid) for uid in (5, 6)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in first_batch + second_batch:
            element.set_vector(np.ones(1, float) * element.uuid())

        # Build initial index: only the first batch should be present.
        lsh.build_index(first_batch)
        self.assertEqual(d_set.count(), 5)
        for element in first_batch:
            self.assertIn(element, d_set)
        for element in second_batch:
            self.assertNotIn(element, d_set)
        # Dummy hash function bins sum of descriptor vectors.
        self.assertEqual(kvs.count(), 5)
        for value in range(5):
            self.assertSetEqual(kvs.get(value), {value})

        # Update index and check that components have new data.
        lsh.update_index(second_batch)
        self.assertEqual(d_set.count(), 7)
        for element in first_batch + second_batch:
            self.assertIn(element, d_set)
        # Dummy hash function bins sum of descriptor vectors.
        self.assertEqual(kvs.count(), 7)
        for value in range(7):
            self.assertSetEqual(kvs.get(value), {value})
Esempio n. 25
0
    def test_update_index_similar_descriptors(self) -> None:
        """
        Test that updating a built index with similar descriptors (same
        vectors, different UUIDs) results in contained structures having an
        expected state.
        """
        descr_set = MemoryDescriptorSet()
        hash_kvs = MemoryKeyValueStore()
        index = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set,
                                        hash_kvs)

        # Similar descriptors (different instances): the first batch is
        # built on, the second is used for the update.
        descriptors1 = [
            DescriptorMemoryElement('t', v).set_vector([v])
            for v in range(5)
        ]
        descriptors2 = [
            DescriptorMemoryElement('t', v + 5).set_vector([v])
            for v in range(5)
        ]

        index.build_index(descriptors1)
        index.update_index(descriptors2)

        assert descr_set.count() == 10
        # Above descriptors should be considered "in" the descriptor set now.
        for d in descriptors1 + descriptors2:
            assert d in descr_set
        # Known hashes of the above descriptors should be in the KVS, each
        # mapping to the two UIDs that share that vector.
        assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
        for v in range(5):
            assert hash_kvs.get(v) == {v, v + 5}
Esempio n. 26
0
    def test_add_many(self) -> None:
        """
        Adding several descriptors at once should store each of them in the
        internal table keyed by its UUID.
        """
        descrs = [
            random_descriptor(),
            random_descriptor(),
            random_descriptor(),
            random_descriptor(),
            random_descriptor(),
        ]
        index = MemoryDescriptorSet()
        index.add_many_descriptors(descrs)

        # Compare code keys of input to code keys in internal table
        self.assertEqual(set(index._table.keys()),
                         {e.uuid() for e in descrs})

        # Get the set of descriptors in the internal table and compare it with
        # the set of generated random descriptors.
        r_set = set(index._table.values())
        self.assertEqual(set(descrs), r_set)
Esempio n. 27
0
    def test_update_index_duplicate_descriptors(self) -> None:
        """
        Test that updating a built index with the same descriptors results in
        idempotent behavior.
        """
        descr_set = MemoryDescriptorSet()
        hash_kvs = MemoryKeyValueStore()
        index = LSHNearestNeighborIndex(DummyHashFunctor(), descr_set,
                                        hash_kvs)

        # Identical descriptors (different instances): the first batch is
        # built on, the second is used for the update.
        descriptors1 = [
            DescriptorMemoryElement('t', v).set_vector([v])
            for v in range(5)
        ]
        descriptors2 = [
            DescriptorMemoryElement('t', v).set_vector([v])
            for v in range(5)
        ]

        index.build_index(descriptors1)
        index.update_index(descriptors2)

        assert descr_set.count() == 5
        # Above descriptors should be considered "in" the descriptor set now.
        for d in descriptors1 + descriptors2:
            assert d in descr_set
        # Known hashes of the above descriptors should be in the KVS, each
        # still mapping to a single UID.
        assert set(hash_kvs.keys()) == {0, 1, 2, 3, 4}
        for v in range(5):
            assert hash_kvs.get(v) == {v}
Esempio n. 28
0
    def test_init_with_cache(self) -> None:
        """
        Constructing with a cache element that holds a pickled table should
        load that table into the new instance.
        """
        d_list = tuple(random_descriptor() for _ in range(4))
        expected_table = {d.uuid(): d for d in d_list}
        pickled_table = pickle.dumps(expected_table)
        expected_cache = DataMemoryElement(bytes=pickled_table)

        inst = MemoryDescriptorSet(expected_cache)

        self.assertEqual(len(inst._table), 4)
        self.assertEqual(inst.cache_element, expected_cache)
        self.assertEqual(inst._table, expected_table)
        self.assertEqual(set(inst._table.values()), set(d_list))
Esempio n. 29
0
    def test_remove_from_index_shared_hashes(self) -> None:
        """
        Test that removing a descriptor (by UID) that shares a hash with other
        descriptors does not trigger removal of its hash.
        """
        # Simulate descriptors all hashing to the same hash value: 0
        constant_hash = mock.Mock(return_value=np.asarray([0], bool))
        hash_func = DummyHashFunctor()
        hash_func.get_hash = constant_hash  # type: ignore

        d_set = MemoryDescriptorSet()
        hash2uids_kvs = MemoryKeyValueStore()
        idx = LSHNearestNeighborIndex(hash_func, d_set, hash2uids_kvs)

        # Descriptors are 1 dim, value == index.
        descriptors = [DescriptorMemoryElement('t', uid) for uid in range(5)]
        # Vectors of length 1 for easy dummy hashing prediction.
        for element in descriptors:
            element.set_vector(np.ones(1, float) * element.uuid())
        idx.build_index(descriptors)

        # After the build, every descriptor is in the descriptor-set and all
        # UIDs sit under the single hash key 0.
        self.assertDictEqual(
            d_set._table,
            {uid: descriptors[uid] for uid in range(5)},
        )
        self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 2, 3, 4}})

        # Mock out hash index as if we had an implementation so we can check
        # call to its remove_from_index method.
        idx.hash_index = mock.Mock(spec=HashIndex)

        idx.remove_from_index([2, 4])

        # Only uid 2 and 4 descriptors should be gone from d-set, kvs should
        # still have the 0 key and its set value should only contain uids 0, 1
        # and 3.  `hash_index.remove_from_index` should not be called because
        # no hashes should be marked for removal.
        self.assertDictEqual(
            d_set._table,
            {uid: descriptors[uid] for uid in (0, 1, 3)},
        )
        self.assertDictEqual(hash2uids_kvs._table, {0: {0, 1, 3}})
        idx.hash_index.remove_from_index.assert_not_called()
Esempio n. 30
0
    def test_has(self) -> None:
        """
        ``has_descriptor`` should be true for the UID of an added descriptor
        and false for a key that was never added.
        """
        d_set = MemoryDescriptorSet()
        added = [random_descriptor() for _ in range(10)]
        d_set.add_many_descriptors(added)

        known_uid = added[4].uuid()
        self.assertTrue(d_set.has_descriptor(known_uid))
        self.assertFalse(d_set.has_descriptor('not_an_int'))