def test_nn_pathological_example(self):
    """
    Query an index built from descriptors that all lie on one line and
    check that fewer neighbors than requested are returned.
    """
    num_descriptors = 10**4
    vector_dim = 256
    # L ~ n/2**depth = 10^4 / 2^10 ~ 10
    tree_depth = 10
    # 3k/L = 60
    num_requested = 200
    tree_count = 60

    descriptors = [
        DescriptorMemoryElement('test', idx)
        for idx in range(num_descriptors)
    ]
    # Every component of a vector equals the element's UUID, which puts all
    # descriptors on a line so that different trees get the same divisions.
    for element in descriptors:
        # noinspection PyTypeChecker
        element.set_vector(
            np.full(vector_dim, element.uuid(), dtype=np.float64)
        )

    query = DescriptorMemoryElement('q', -1)
    query.set_vector(np.zeros((vector_dim, )))

    descriptor_set = MemoryDescriptorSet()
    index = MRPTNearestNeighborsIndex(
        descriptor_set,
        num_trees=tree_count,
        depth=tree_depth,
        random_seed=0,
    )
    index.build_index(descriptors)

    neighbors, distances = index.nn(query, num_requested)
    self.assertEqual(len(neighbors), len(distances))
    # We should get about 10 descriptors back instead of the requested 200.
    self.assertLess(len(neighbors), 20)
def test_nn_small_leaves(self):
    """
    Query an index built from random descriptors with a deep tree and check
    that exactly the requested number of neighbors is returned.
    """
    # Fix the global NumPy seed so the random vectors are reproducible.
    np.random.seed(0)

    num_descriptors = 10**4
    vector_dim = 256
    # L ~ n/2**depth = 10^4 / 2^10 ~ 10
    tree_depth = 10
    # 3k/L = 60
    num_requested = 200
    tree_count = 60

    descriptors = [
        DescriptorMemoryElement('test', idx)
        for idx in range(num_descriptors)
    ]
    for element in descriptors:
        element.set_vector(np.random.rand(vector_dim))

    query = DescriptorMemoryElement('q', -1)
    query.set_vector(np.zeros((vector_dim, )))

    descriptor_set = MemoryDescriptorSet()
    index = MRPTNearestNeighborsIndex(
        descriptor_set,
        num_trees=tree_count,
        depth=tree_depth,
        random_seed=0,
    )
    index.build_index(descriptors)

    neighbors, distances = index.nn(query, num_requested)
    self.assertEqual(len(neighbors), len(distances))
    self.assertEqual(len(neighbors), num_requested)
def test_configuration(self):
    """
    Construct an index with explicit parameter values and check that every
    instance yielded by ``configuration_test_helper`` carries those values.
    """
    index_filepath = osp.abspath(osp.expanduser('index_filepath'))
    para_filepath = osp.abspath(osp.expanduser('param_fp'))

    ctor_kwargs = dict(
        descriptor_set=MemoryDescriptorSet(),
        index_filepath=index_filepath,
        parameters_filepath=para_filepath,
        read_only=True,
        num_trees=9,
        depth=2,
        random_seed=8,
        pickle_protocol=0,
        use_multiprocessing=True,
    )
    original = MRPTNearestNeighborsIndex(**ctor_kwargs)

    instances = configuration_test_helper(original)
    for inst in instances:  # type: MRPTNearestNeighborsIndex
        assert isinstance(inst._descriptor_set, MemoryDescriptorSet)
        # File paths should match the absolute paths given at construction.
        assert inst._index_filepath == index_filepath
        assert inst._index_param_filepath == para_filepath
        assert inst._read_only == True
        # Tree-building parameters.
        assert inst._num_trees == 9
        assert inst._depth == 2
        assert inst._rand_seed == 8
        assert inst._pickle_protocol == 0
        assert inst._use_multiprocessing == True
def _make_inst(self, **kwargs):
    """
    Make an instance of MRPTNearestNeighborsIndex backed by a fresh
    MemoryDescriptorSet.

    Keyword arguments are forwarded to the constructor. When the caller does
    not pass ``random_seed``, ``self.RAND_SEED`` is used for it.
    """
    caller_gave_seed = 'random_seed' in kwargs
    if not caller_gave_seed:
        kwargs['random_seed'] = self.RAND_SEED
    descriptor_set = MemoryDescriptorSet()
    return MRPTNearestNeighborsIndex(descriptor_set, **kwargs)
def test_configuration(self):
    """
    Build an index from a modified default configuration and check that the
    file paths are applied and that ``get_config`` returns the same
    configuration that was used for construction.
    """
    # Start from the default configuration and override the entries under test.
    config = MRPTNearestNeighborsIndex.get_default_config()

    index_filepath = osp.abspath(osp.expanduser('index_filepath'))
    para_filepath = osp.abspath(osp.expanduser('param_fp'))
    config['index_filepath'] = index_filepath
    config['parameters_filepath'] = para_filepath
    config['descriptor_set']['type'] = 'MemoryDescriptorIndex'

    # Construct the index from that configuration.
    built = MRPTNearestNeighborsIndex.from_config(config)
    ntools.assert_equal(built._index_filepath, index_filepath)
    ntools.assert_equal(built._index_param_filepath, para_filepath)

    # The instance should report the configuration it was built from.
    ntools.assert_equal(config, built.get_config())
def test_configuration(self):
    """
    Build an index from a modified default configuration, check that the
    file paths are applied, and check that rebuilding from the instance's
    own configuration yields the same configuration again.
    """
    # Start from the default configuration and override the entries under test.
    config = MRPTNearestNeighborsIndex.get_default_config()

    index_filepath = osp.abspath(osp.expanduser('index_filepath'))
    para_filepath = osp.abspath(osp.expanduser('param_fp'))
    config['index_filepath'] = index_filepath
    config['parameters_filepath'] = para_filepath
    config['descriptor_set']['type'] = 'MemoryDescriptorIndex'

    # Construct the index from that configuration.
    first = MRPTNearestNeighborsIndex.from_config(config)
    ntools.assert_equal(first._index_filepath, index_filepath)
    ntools.assert_equal(first._index_param_filepath, para_filepath)

    # Constructing a new instance from ``first``'s config should yield an
    # index with the same configuration (idempotent).
    first_config = first.get_config()
    second = MRPTNearestNeighborsIndex.from_config(first_config)
    ntools.assert_equal(first.get_config(), second.get_config())
def test_many_descriptors(self):
    """
    Build an index over 10^4 random descriptors and check that a query for
    10 neighbors returns 10 neighbors with matching distances.
    """
    # Fix the global NumPy seed so the random vectors are reproducible.
    np.random.seed(0)

    num_descriptors = 10 ** 4
    vector_dim = 256
    tree_depth = 5
    tree_count = 10
    num_requested = 10

    descriptors = [
        DescriptorMemoryElement('test', idx)
        for idx in range(num_descriptors)
    ]
    for element in descriptors:
        element.set_vector(np.random.rand(vector_dim))

    query = DescriptorMemoryElement('q', -1)
    query.set_vector(np.zeros((vector_dim,)))

    descriptor_index = MemoryDescriptorIndex()
    index = MRPTNearestNeighborsIndex(
        descriptor_index,
        num_trees=tree_count,
        depth=tree_depth,
        random_seed=0,
    )
    index.build_index(descriptors)

    neighbors, distances = index.nn(query, num_requested)
    ntools.assert_equal(len(neighbors), len(distances))
    ntools.assert_equal(len(neighbors), num_requested)