Ejemplo n.º 1
0
    def test_nn_pathological_example(self) -> None:
        """
        Degenerate data set: every descriptor lies on one line.

        Because all points are collinear, the different random-projection
        trees end up with the same divisions, so only about one leaf's worth
        of candidates (~10) is found. Asking for ``k = 200`` neighbors should
        therefore return far fewer than ``k`` results.
        """
        n = 10**4
        dim = 256
        depth = 10
        # Expected leaf size L ~ n/2**depth = 10^4 / 2^10 ~ 10
        k = 200
        # num_trees = 3k/L = 60
        num_trees = 60

        d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
        # Put all descriptors on a line so that different trees get same
        # divisions: each descriptor's vector is filled with its own UUID.
        # A plain loop is used because set_vector is called for its side
        # effect only.
        for d in d_set:
            # noinspection PyTypeChecker
            d.set_vector(np.full(dim, d.uuid(), dtype=np.float64))
        q = DescriptorMemoryElement('q', -1)
        q.set_vector(np.zeros((dim, )))

        di = MemoryDescriptorSet()
        mrpt = MRPTNearestNeighborsIndex(di,
                                         num_trees=num_trees,
                                         depth=depth,
                                         random_seed=0)
        mrpt.build_index(d_set)

        nbrs, dists = mrpt.nn(q, k)
        self.assertEqual(len(nbrs), len(dists))
        # We should get about 10 descriptors back instead of the requested
        # 200
        self.assertLess(len(nbrs), 20)
Ejemplo n.º 2
0
    def test_nn_small_leaves(self) -> None:
        """
        Random data with leaves much smaller than ``k``.

        Each leaf holds only ~10 descriptors, but the candidates gathered
        across ``num_trees`` trees should still be enough to return exactly
        ``k = 200`` neighbors.
        """
        np.random.seed(0)

        n = 10**4
        dim = 256
        depth = 10
        # Expected leaf size L ~ n/2**depth = 10^4 / 2^10 ~ 10
        k = 200
        # num_trees = 3k/L = 60
        num_trees = 60

        d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
        # Plain loop: set_vector is called only for its side effect.
        for d in d_set:
            d.set_vector(np.random.rand(dim))
        q = DescriptorMemoryElement('q', -1)
        q.set_vector(np.zeros((dim, )))

        di = MemoryDescriptorSet()
        mrpt = MRPTNearestNeighborsIndex(di,
                                         num_trees=num_trees,
                                         depth=depth,
                                         random_seed=0)
        mrpt.build_index(d_set)

        nbrs, dists = mrpt.nn(q, k)
        self.assertEqual(len(nbrs), len(dists))
        self.assertEqual(len(nbrs), k)
Ejemplo n.º 3
0
    def test_configuration(self) -> None:
        """
        Every constructor parameter must be preserved on each instance
        yielded by ``configuration_test_helper``.

        Uses ``self.assert*`` rather than bare ``assert`` so the checks are
        not stripped under ``python -O``. This also matches the other tests
        in this class and gives descriptive failure messages.
        """
        index_filepath = osp.abspath(osp.expanduser('index_filepath'))
        para_filepath = osp.abspath(osp.expanduser('param_fp'))

        i = MRPTNearestNeighborsIndex(
            descriptor_set=MemoryDescriptorSet(),
            index_filepath=index_filepath,
            parameters_filepath=para_filepath,
            read_only=True,
            num_trees=9,
            depth=2,
            random_seed=8,
            pickle_protocol=0,
            use_multiprocessing=True,
        )
        for inst in configuration_test_helper(i):  # type: MRPTNearestNeighborsIndex
            self.assertIsInstance(inst._descriptor_set, MemoryDescriptorSet)
            self.assertEqual(inst._index_filepath, index_filepath)
            self.assertEqual(inst._index_param_filepath, para_filepath)
            self.assertIs(inst._read_only, True)
            self.assertEqual(inst._num_trees, 9)
            self.assertEqual(inst._depth, 2)
            self.assertEqual(inst._rand_seed, 8)
            self.assertEqual(inst._pickle_protocol, 0)
            self.assertIs(inst._use_multiprocessing, True)
Ejemplo n.º 4
0
 def _make_inst(self, **kwargs: Any) -> MRPTNearestNeighborsIndex:
     """
     Construct an MRPTNearestNeighborsIndex over a newly created
     MemoryDescriptorSet.

     All keyword arguments are forwarded to the constructor. If the caller
     did not pass ``random_seed``, the class attribute ``RAND_SEED`` is
     used in its place.
     """
     params = kwargs
     # Only read RAND_SEED when the caller did not choose a seed.
     if 'random_seed' not in params:
         params['random_seed'] = self.RAND_SEED
     backing_set = MemoryDescriptorSet()
     return MRPTNearestNeighborsIndex(backing_set, **params)
Ejemplo n.º 5
0
    def test_nn_many_descriptors(self) -> None:
        """
        Query a moderately sized random index and expect exactly ``k``
        neighbors, with one distance per neighbor.
        """
        np.random.seed(0)

        n = 10**4
        dim = 256
        depth = 5
        num_trees = 10
        # Number of neighbors requested (and expected back).
        k = 10

        d_set = [DescriptorMemoryElement('test', i) for i in range(n)]
        # Plain loop: set_vector is called only for its side effect.
        for d in d_set:
            d.set_vector(np.random.rand(dim))
        q = DescriptorMemoryElement('q', -1)
        q.set_vector(np.zeros((dim, )))

        di = MemoryDescriptorSet()
        mrpt = MRPTNearestNeighborsIndex(di,
                                         num_trees=num_trees,
                                         depth=depth,
                                         random_seed=0)
        mrpt.build_index(d_set)

        nbrs, dists = mrpt.nn(q, k)
        self.assertEqual(len(nbrs), len(dists))
        self.assertEqual(len(nbrs), k)