def test_count_nonempty(self):
    """
    ``count()`` reports the number of vectors the index was built from.
    """
    bt = SkLearnBallTreeHashIndex()
    # Make 234 random bit vectors of length 256.  The count lives in one
    # local so the data shape and the expected count cannot drift apart.
    n_vectors, n_bits = 234, 256
    m = np.random.randint(0, 2, n_vectors * n_bits).reshape(n_vectors, n_bits)
    bt.build_index(m)
    self.assertEqual(bt.count(), n_vectors)
def test_save_model_with_cache(self, m_savez):
    """
    Building an index that has a cache element set invokes the serialization
    function exactly once.

    ``m_savez`` is presumably the patched ``numpy.savez`` mock supplied by a
    ``mock.patch`` decorator.
    """
    rows, bits = 1000, 256
    bt = SkLearnBallTreeHashIndex(DataMemoryElement(), random_seed=0)
    bit_matrix = numpy.random.randint(0, 2, rows * bits).reshape(rows, bits)
    bt.build_index(bit_matrix)
    nose.tools.assert_true(m_savez.called)
    nose.tools.assert_equal(m_savez.call_count, 1)
def test_save_model_no_cache(self, m_savez):
    """
    Without a cache element, building the index must not invoke the
    serialization function (``m_savez``).
    """
    index = SkLearnBallTreeHashIndex()
    shape = (1000, 256)
    bits = numpy.random.randint(0, 2, shape[0] * shape[1]).reshape(*shape)
    index.build_index(bits)
    # No cache element was configured, so there is nowhere to serialize to.
    nose.tools.assert_false(m_savez.called)
def test_save_model_no_cache(self, m_savez):
    """
    Building the tree internally with no cache element set must not call the
    serialization function (``m_savez``).
    """
    bt = SkLearnBallTreeHashIndex()
    n_vec, n_bits = 1000, 256
    m = np.random.randint(0, 2, n_vec * n_bits).reshape(n_vec, n_bits)
    bt._build_bt_internal(m)
    # No cache element configured -> serializer must stay uncalled.
    self.assertFalse(m_savez.called)
def test_save_model_with_cache(self, m_savez):
    """
    With a cache element set, building the tree internally serializes the
    model exactly once through ``m_savez``.
    """
    bt = SkLearnBallTreeHashIndex(DataMemoryElement(), random_seed=0)
    vectors = np.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    bt._build_bt_internal(vectors)
    self.assertTrue(m_savez.called)
    self.assertEqual(m_savez.call_count, 1)
def test_get_config(self):
    """
    A default-constructed index reports exactly three config keys, with
    ``file_cache`` unset (``None``).
    """
    cfg = SkLearnBallTreeHashIndex().get_config()
    nose.tools.assert_equal(len(cfg), 3)
    for key in ('file_cache', 'leaf_size', 'random_seed'):
        nose.tools.assert_in(key, cfg)
    nose.tools.assert_equal(cfg['file_cache'], None)
def test_get_config(self):
    """
    A default-constructed index reports exactly three config keys; the
    ``cache_element`` entry is a nested plugin config whose ``type`` is None.
    """
    cfg = SkLearnBallTreeHashIndex().get_config()
    nose.tools.assert_equal(len(cfg), 3)
    for key in ('cache_element', 'leaf_size', 'random_seed'):
        nose.tools.assert_in(key, cfg)
    cache_cfg = cfg['cache_element']
    nose.tools.assert_is_instance(cache_cfg, dict)
    nose.tools.assert_is_none(cache_cfg['type'])
def test_get_config(self):
    """
    A default-constructed index reports exactly three config keys; the
    ``cache_element`` entry is a dict whose ``type`` is None.
    """
    cfg = SkLearnBallTreeHashIndex().get_config()
    self.assertEqual(len(cfg), 3)
    for key in ('cache_element', 'leaf_size', 'random_seed'):
        self.assertIn(key, cfg)
    cache_cfg = cfg['cache_element']
    self.assertIsInstance(cache_cfg, dict)
    self.assertIsNone(cache_cfg['type'])
def test_build_index(self):
    """
    After ``build_index`` the ball tree exists and holds exactly the rows
    it was given.
    """
    bt = SkLearnBallTreeHashIndex(random_seed=0)
    # 1000 random bit vectors, 256 bits each.
    m = np.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    bt.build_index(m)
    self.assertIsNotNone(bt.bt)
    # Sort both row sets so the comparison does not depend on row order.
    indexed_rows = sorted(np.array(bt.bt.data).tolist())
    source_rows = sorted(m.tolist())
    np.testing.assert_array_almost_equal(indexed_rows, source_rows)
def test_init_consistency(self):
    """
    ``from_config(c).get_config()`` round-trips to ``c``, both for the default
    config and for one with a non-null cache element type.
    """
    cls = SkLearnBallTreeHashIndex
    config = cls.get_default_config()
    # The default config must itself be a valid configuration.
    self.assertEqual(cls.from_config(config).get_config(), config)
    # Repeat with a concrete cache element type selected.
    config['cache_element']['type'] = 'DataMemoryElement'
    self.assertEqual(cls.from_config(config).get_config(), config)
def test_remove_from_index_invalid_key_single(self):
    """
    Removing a key that was never indexed raises ``KeyError`` and leaves the
    ball tree's data unchanged.
    """
    bt = SkLearnBallTreeHashIndex(random_seed=0)
    index = np.ndarray((1000, 256), bool)
    for row_idx in range(1000):
        index[row_idx] = int_to_bit_vector_large(row_idx, 256)
    bt.build_index(index)
    # Snapshot the built data to confirm nothing gets removed.
    before = np.copy(bt.bt.data)
    missing = int_to_bit_vector_large(1001, 256)  # outside 0..999
    self.assertRaises(KeyError, bt.remove_from_index, [missing])
    np.testing.assert_array_equal(before, np.asarray(bt.bt.data))
def main():
    """
    Build a ball-tree hash index over every hash-code key held in the
    configured hash2uuid key-value store.
    """
    args = cli_parser().parse_args()
    config = smqtk.utils.bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Values read from the configuration.
    bit_len = int(config['itq_bit_length'])

    log.info("Loading hash2uuid KeyValue store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuid_kv_store = smqtk.utils.plugin.from_plugin_config(
        config['hash2uuid_kv_store'],
        smqtk.representation.get_key_value_store_impls(),
    )

    log.info("Initializing ball tree")
    btree = SkLearnBallTreeHashIndex.from_config(config['sklearn_balltree'])

    log.info("Computing hash-code vectors")
    reporter = smqtk.utils.bin_utils.ProgressReporter(log.debug, 1.0)
    reporter.start()
    hash_vectors = []
    for hash_int in hash2uuid_kv_store.keys():
        vector = smqtk.utils.bit_utils.int_to_bit_vector_large(hash_int,
                                                               bit_len)
        hash_vectors.append(vector)
        reporter.increment_report()
    reporter.report()

    log.info("Building ball tree index")
    btree.build_index(hash_vectors)
def test_init_without_cache(self):
    """
    Constructor arguments are stored as given when no cache element is
    supplied, and no ball tree exists yet.
    """
    leaf, seed = 52, 42
    idx = SkLearnBallTreeHashIndex(cache_element=None, leaf_size=leaf,
                                   random_seed=seed)
    nose.tools.assert_is_none(idx.cache_element)
    nose.tools.assert_equal(idx.leaf_size, leaf)
    nose.tools.assert_equal(idx.random_seed, seed)
    # Nothing has been built yet.
    nose.tools.assert_is_none(idx.bt)
def test_default_configuration(self):
    """
    The default config has three entries: an untyped ``cache_element`` plugin
    config, ``leaf_size`` of 40 and an unset ``random_seed``.
    """
    c = SkLearnBallTreeHashIndex.get_default_config()
    nose.tools.assert_equal(len(c), 3)
    cache_cfg = c['cache_element']
    nose.tools.assert_is_instance(cache_cfg, dict)
    nose.tools.assert_is_none(cache_cfg['type'])
    nose.tools.assert_equal(c['leaf_size'], 40)
    nose.tools.assert_is_none(c['random_seed'])
def test_init_without_cache(self):
    """
    With ``cache_element=None`` the other constructor arguments are stored
    as given and no ball tree is built yet.
    """
    leaf, seed = 52, 42
    idx = SkLearnBallTreeHashIndex(cache_element=None, leaf_size=leaf,
                                   random_seed=seed)
    self.assertIsNone(idx.cache_element)
    self.assertEqual(idx.leaf_size, leaf)
    self.assertEqual(idx.random_seed, seed)
    self.assertIsNone(idx.bt)
def test_remove_from_index_invalid_key_multiple(self):
    """
    A removal request mixing a valid key with an invalid one raises
    ``KeyError`` and leaves the ball tree's data unchanged.
    """
    bt = SkLearnBallTreeHashIndex(random_seed=0)
    index = np.ndarray((1000, 256), bool)
    for n in range(1000):
        index[n] = int_to_bit_vector_large(n, 256)
    bt.build_index(index)
    # Snapshot the built data to confirm nothing gets removed.
    snapshot = np.copy(bt.bt.data)
    to_remove = [
        int_to_bit_vector_large(42, 256),    # indexed (0..999)
        int_to_bit_vector_large(1008, 256),  # never indexed
    ]
    self.assertRaises(KeyError, bt.remove_from_index, to_remove)
    np.testing.assert_array_equal(snapshot, np.asarray(bt.bt.data))
def test_default_configuration(self):
    """
    The default config has three entries: an untyped ``cache_element`` dict,
    ``leaf_size`` of 40 and an unset ``random_seed``.
    """
    c = SkLearnBallTreeHashIndex.get_default_config()
    self.assertEqual(len(c), 3)
    cache_cfg = c['cache_element']
    self.assertIsInstance(cache_cfg, dict)
    self.assertIsNone(cache_cfg['type'])
    self.assertEqual(c['leaf_size'], 40)
    self.assertIsNone(c['random_seed'])
def main():
    """
    Entry point: load the hash2uuid key-value store named in the config,
    convert each stored hash integer into a bit vector of the configured
    length, and build a ball-tree index over those vectors.
    """
    args = cli_parser().parse_args()
    config = smqtk.utils.bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    bit_len = int(config['itq_bit_length'])

    log.info("Loading hash2uuid KeyValue store")
    #: :type: smqtk.representation.KeyValueStore
    kv_store = smqtk.utils.plugin.from_plugin_config(
        config['hash2uuid_kv_store'],
        smqtk.representation.get_key_value_store_impls()
    )

    log.info("Initializing ball tree")
    balltree = SkLearnBallTreeHashIndex.from_config(
        config['sklearn_balltree'])

    log.info("Computing hash-code vectors")
    vectors = []
    progress = smqtk.utils.bin_utils.ProgressReporter(log.debug, 1.0)
    progress.start()
    for code in kv_store.keys():
        vectors.append(
            smqtk.utils.bit_utils.int_to_bit_vector_large(code, bit_len)
        )
        progress.increment_report()
    # Emit a final progress line once all keys are converted.
    progress.report()

    log.info("Building ball tree index")
    balltree.build_index(vectors)
def test_invalid_build(self):
    """Building from an empty sequence of vectors raises ``ValueError``."""
    index = SkLearnBallTreeHashIndex()
    nose.tools.assert_raises(ValueError, index.build_index, [])
def test_nn_no_index(self):
    """
    Querying ``nn`` before any index has been built raises ``ValueError``
    with an explanatory message.
    """
    i = SkLearnBallTreeHashIndex()
    # ``assertRaisesRegexp`` is a deprecated alias that was removed in
    # Python 3.12.  The context-manager form works on Python 2.7 and 3.x.
    # The expected text has no regex metacharacters, so a substring check is
    # equivalent to the regex search the alias performed.
    with self.assertRaises(ValueError) as cm:
        i.nn([0, 0, 0])
    self.assertIn("No index currently set to query from", str(cm.exception))
def default_config():
    """
    Return the default configuration dictionary for this utility: a plugin
    config for the hash2uuid key-value store, the ball-tree index defaults,
    and the ITQ bit length (256).
    """
    kv_store_config = smqtk.utils.plugin.make_config(
        smqtk.representation.get_key_value_store_impls()
    )
    return {
        "hash2uuid_kv_store": kv_store_config,
        "sklearn_balltree": SkLearnBallTreeHashIndex.get_default_config(),
        "itq_bit_length": 256,
    }
def test_save_model_with_readonly_cache(self):
    """
    Building the tree when the cache element is read-only raises
    ``ValueError``.
    """
    readonly_cache = DataMemoryElement(readonly=True)
    bt = SkLearnBallTreeHashIndex(readonly_cache)
    m = np.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    with self.assertRaises(ValueError):
        bt._build_bt_internal(m)
def test_nn_no_index(self):
    """
    ``nn`` on an index that was never built raises ``ValueError`` whose
    message matches the expected text.
    """
    expected_msg = "No index currently set to query from"
    index = SkLearnBallTreeHashIndex()
    nose.tools.assert_raises_regexp(
        ValueError, expected_msg, index.nn, [0, 0, 0]
    )
def test_save_model_with_readonly_cache(self):
    """
    ``build_index`` raises ``ValueError`` when the cache element is
    read-only.
    """
    bt = SkLearnBallTreeHashIndex(DataMemoryElement(readonly=True))
    vectors = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    nose.tools.assert_raises(ValueError, bt.build_index, vectors)
def default_config():
    """
    Return the default configuration dictionary for this utility: a default
    KeyValueStore plugin config, the ball-tree index defaults, and the ITQ
    bit length (256).
    """
    kv_impls = smqtk.representation.KeyValueStore.get_impls()
    return {
        "hash2uuid_kv_store": make_default_config(kv_impls),
        "sklearn_balltree": SkLearnBallTreeHashIndex.get_default_config(),
        "itq_bit_length": 256,
    }
def test_init_with_empty_cache(self):
    """
    Constructor arguments, including an empty cache element, are stored as
    given, and no ball tree is built at construction time.
    """
    cache = DataMemoryElement()
    leaf, seed = 52, 42
    i = SkLearnBallTreeHashIndex(cache_element=cache, leaf_size=leaf,
                                 random_seed=seed)
    self.assertEqual(i.cache_element, cache)
    self.assertEqual(i.leaf_size, leaf)
    self.assertEqual(i.random_seed, seed)
    self.assertIsNone(i.bt)
def test_file_cache_type(self):
    """
    ``file_cache`` must name a ``.npz`` file: any other extension raises
    ``ValueError``, while a ``.npz`` path constructs normally.
    """
    # Requires a .npz file.
    nose.tools.assert_raises(
        ValueError, SkLearnBallTreeHashIndex, file_cache='some_file.txt'
    )
    SkLearnBallTreeHashIndex(file_cache='some_file.npz')
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Keys: ``hash2uuid_kv_store`` (KeyValueStore plugin config),
    ``sklearn_balltree`` (ball-tree index defaults) and ``itq_bit_length``
    (256).
    """
    config = {}
    config["hash2uuid_kv_store"] = smqtk.utils.plugin.make_config(
        smqtk.representation.get_key_value_store_impls())
    config["sklearn_balltree"] = SkLearnBallTreeHashIndex.get_default_config()
    config["itq_bit_length"] = 256
    return config
def test_remove_from_index_no_index(self):
    """
    Removing from an index that has not been built yet raises ``KeyError``.
    """
    bt = SkLearnBallTreeHashIndex(random_seed=0)
    rm_hash = np.random.randint(0, 2, 256)
    # ``assertRaisesRegexp`` is a deprecated alias removed in Python 3.12;
    # the context-manager form works on Python 2.7 and 3.x.
    # ``str(rm_hash[0])`` is a single "0" or "1" digit with no regex
    # metacharacters, so a substring check matches what a regex search did.
    # NOTE(review): this only checks that the message contains that digit.
    with self.assertRaises(KeyError) as cm:
        bt.remove_from_index([rm_hash])
    self.assertIn(str(rm_hash[0]), str(cm.exception))
def main():
    """
    Build a ball-tree hash index from the hash codes in a pickled hash2uuids
    table, using ``balltree_model_fp`` as the index's model file.

    Raises:
        ValueError: if ``hash2uuids_fp`` is not an existing file, or the
            directory of ``balltree_model_fp`` does not exist.
    """
    args = cli_parser().parse_args()

    initialize_logging(logging.getLogger('smqtk'), logging.DEBUG)
    initialize_logging(logging.getLogger('__main__'), logging.DEBUG)
    log = logging.getLogger(__name__)

    hash2uuids_fp = os.path.abspath(args.hash2uuids_fp)
    bit_len = args.bit_len
    leaf_size = args.leaf_size
    rand_seed = args.rand_seed
    balltree_model_fp = os.path.abspath(args.balltree_model_fp)

    # Explicit checks instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently skip this validation.
    if not os.path.isfile(hash2uuids_fp):
        raise ValueError("Bad path: '%s'" % hash2uuids_fp)
    if not os.path.isdir(os.path.dirname(balltree_model_fp)):
        raise ValueError("Bad path: %s" % balltree_model_fp)

    log.debug("hash2uuids_fp : %s", hash2uuids_fp)
    log.debug("bit_len : %d", bit_len)
    log.debug("leaf_size : %d", leaf_size)
    log.debug("rand_seed : %d", rand_seed)
    log.debug("balltree_model_fp: %s", balltree_model_fp)

    log.info("Loading hash2uuids table")
    # Pickles are binary data, so open the file in binary mode.
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # trusted, locally-produced tables.
    with open(hash2uuids_fp, 'rb') as f:
        hash2uuids = cPickle.load(f)

    log.info("Computing hash-code vectors")
    hash_vectors = []
    # Mutable state list threaded through report_progress calls; presumably
    # holds its counters/timestamps -- TODO confirm against report_progress.
    rs = [0] * 7
    for h in hash2uuids:
        hash_vectors.append(int_to_bit_vector_large(h, bit_len))
        report_progress(log.debug, rs, 1.)

    log.info("Initializing ball tree")
    btree = SkLearnBallTreeHashIndex(balltree_model_fp, leaf_size, rand_seed)

    log.info("Building ball tree")
    btree.build_index(hash_vectors)
def test_remove_from_index_last_element_with_cache(self):
    """
    Removing the only indexed element leaves both the index and its cache
    element empty.
    """
    cache = DataMemoryElement()
    bt = SkLearnBallTreeHashIndex(cache_element=cache, random_seed=0)
    single = np.ndarray((1, 256), bool)
    single[0] = int_to_bit_vector_large(1, 256)
    bt.build_index(single)
    # One vector indexed, and the cache has been written.
    self.assertEqual(bt.count(), 1)
    self.assertFalse(cache.is_empty())
    bt.remove_from_index(single)
    # Both the index and the cache are emptied by the removal.
    self.assertEqual(bt.count(), 0)
    self.assertTrue(cache.is_empty())
def test_model_reload(self):
    """
    A second instance constructed on the same ``.npz`` model file gives the
    same ``nn`` results as the instance that built and saved it, while
    holding a distinct ball-tree object.
    """
    fd, fp = tempfile.mkstemp('.npz')
    os.close(fd)
    os.remove(fp)  # shouldn't exist before construction
    try:
        bt = SkLearnBallTreeHashIndex(fp)
        m = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
        bt.build_index(m)
        q = numpy.random.randint(0, 2, 256).astype(bool)
        bt_neighbors, bt_dists = bt.nn(q, 10)

        bt2 = SkLearnBallTreeHashIndex(fp)
        bt2_neighbors, bt2_dists = bt2.nn(q, 10)

        nose.tools.assert_is_not(bt, bt2)
        nose.tools.assert_is_not(bt.bt, bt2.bt)
        numpy.testing.assert_equal(bt2_neighbors, bt_neighbors)
        numpy.testing.assert_equal(bt2_dists, bt_dists)
    finally:
        # The model file exists only if building got far enough to save it.
        # An unconditional remove would raise OSError here and mask the
        # original test failure.
        if os.path.exists(fp):
            os.remove(fp)
import os import tempfile import unittest import mock import nose.tools import numpy from smqtk.algorithms.nn_index.hash_index.sklearn_balltree import \ SkLearnBallTreeHashIndex __author__ = "*****@*****.**" if SkLearnBallTreeHashIndex.is_usable(): class TestBallTreeHashIndex (unittest.TestCase): def test_file_cache_type(self): # Requites .npz file nose.tools.assert_raises( ValueError, SkLearnBallTreeHashIndex, file_cache='some_file.txt' ) SkLearnBallTreeHashIndex(file_cache='some_file.npz') @mock.patch('smqtk.algorithms.nn_index.hash_index.sklearn_balltree.numpy.savez') def test_save_model_no_cache(self, m_savez):
def test_save_model_no_cache(self, m_savez):
    """
    With no file cache configured, ``build_index`` never calls the
    serialization function (``m_savez``).
    """
    index = SkLearnBallTreeHashIndex()
    vectors = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    index.build_index(vectors)
    nose.tools.assert_false(m_savez.called)
def test_save_model_with_cache(self, m_savez):
    """
    With a ``.npz`` file cache configured, ``build_index`` serializes the
    model exactly once through ``m_savez``.
    """
    index = SkLearnBallTreeHashIndex('some_file.npz')
    vectors = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
    index.build_index(vectors)
    nose.tools.assert_true(m_savez.called)
    nose.tools.assert_equal(m_savez.call_count, 1)