Example #1
0
    def test_count_nonempty(self):
        # After building from N hash vectors, ``count`` should report N.
        n_vectors, n_bits = 234, 256
        hash_index = SkLearnBallTreeHashIndex()
        # Make 234 random bit vectors of length 256
        random_bits = np.random.randint(0, 2, n_vectors * n_bits)
        hash_index.build_index(random_bits.reshape(n_vectors, n_bits))

        self.assertEqual(hash_index.count(), n_vectors)
Example #2
0
 def test_save_model_with_cache(self, m_savez):
     # With a cache element given, building the index is expected to
     # invoke the patched serialization function exactly once.
     # ``m_savez`` is the mock injected by a patch decorator that is not
     # visible here.
     n_rows, n_bits = 1000, 256
     cache = DataMemoryElement()
     hash_index = SkLearnBallTreeHashIndex(cache, random_seed=0)
     random_bits = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_index.build_index(random_bits.reshape(n_rows, n_bits))
     nose.tools.assert_true(m_savez.called)
     nose.tools.assert_equal(m_savez.call_count, 1)
Example #3
0
 def test_save_model_no_cache(self, m_savez):
     # ``m_savez`` is the mock injected by a patch decorator that is not
     # visible here.
     n_rows, n_bits = 1000, 256
     hash_index = SkLearnBallTreeHashIndex()
     random_bits = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_index.build_index(random_bits.reshape(n_rows, n_bits))
     # No cache element was set, so the underlying serialization function
     # must not have been called.
     nose.tools.assert_false(m_savez.called)
Example #4
0
 def test_save_model_no_cache(self, m_savez):
     # ``m_savez`` is the mock injected by a patch decorator that is not
     # visible here.
     n_rows, n_bits = 1000, 256
     hash_index = SkLearnBallTreeHashIndex()
     random_bits = np.random.randint(0, 2, n_rows * n_bits)
     hash_index._build_bt_internal(random_bits.reshape(n_rows, n_bits))
     # No cache element was set, so the underlying serialization function
     # must not have been called.
     self.assertFalse(m_savez.called)
Example #5
0
 def test_save_model_with_cache(self, m_savez):
     # With a cache element given, the internal build is expected to
     # invoke the patched serialization function exactly once.
     # ``m_savez`` is the mock injected by a patch decorator that is not
     # visible here.
     n_rows, n_bits = 1000, 256
     cache = DataMemoryElement()
     hash_index = SkLearnBallTreeHashIndex(cache, random_seed=0)
     random_bits = np.random.randint(0, 2, n_rows * n_bits)
     hash_index._build_bt_internal(random_bits.reshape(n_rows, n_bits))
     self.assertTrue(m_savez.called)
     self.assertEqual(m_savez.call_count, 1)
Example #6
0
        def test_get_config(self):
            # A default-constructed index should report exactly three
            # configuration entries, with no file cache configured.
            bt = SkLearnBallTreeHashIndex()
            bt_c = bt.get_config()

            nose.tools.assert_equal(len(bt_c), 3)
            for expected_key in ('file_cache', 'leaf_size', 'random_seed'):
                nose.tools.assert_in(expected_key, bt_c)

            # Identity check: an equality comparison against None could be
            # satisfied by any object with a permissive ``__eq__``.
            nose.tools.assert_is_none(bt_c['file_cache'])
Example #7
0
    def test_get_config(self):
        # A default-constructed index should report exactly three
        # configuration entries, with a nested cache-element config whose
        # ``type`` is unset.
        config = SkLearnBallTreeHashIndex().get_config()

        nose.tools.assert_equal(len(config), 3)
        for expected_key in ('cache_element', 'leaf_size', 'random_seed'):
            nose.tools.assert_in(expected_key, config)

        cache_config = config['cache_element']
        nose.tools.assert_is_instance(cache_config, dict)
        nose.tools.assert_is_none(cache_config['type'])
Example #8
0
    def test_get_config(self):
        # A default-constructed index should report exactly three
        # configuration entries, with a nested cache-element config whose
        # ``type`` is unset.
        config = SkLearnBallTreeHashIndex().get_config()

        self.assertEqual(len(config), 3)
        for expected_key in ('cache_element', 'leaf_size', 'random_seed'):
            self.assertIn(expected_key, config)

        cache_config = config['cache_element']
        self.assertIsInstance(cache_config, dict)
        self.assertIsNone(cache_config['type'])
Example #9
0
    def test_build_index(self):
        # Building from a random bit matrix should populate ``bt`` with the
        # same rows that were given.
        n_rows, n_bits = 1000, 256
        hash_index = SkLearnBallTreeHashIndex(random_seed=0)
        # Make 1000 random bit vectors of length 256
        source = np.random.randint(0, 2, n_rows * n_bits)
        source = source.reshape(n_rows, n_bits)
        hash_index.build_index(source)

        self.assertIsNotNone(hash_index.bt)
        # Sort both sides deterministically so the comparison does not
        # depend on the row order of either array.
        indexed_rows = sorted(np.array(hash_index.bt.data).tolist())
        source_rows = sorted(source.tolist())
        np.testing.assert_array_almost_equal(indexed_rows, source_rows)
Example #10
0
    def test_init_consistency(self):
        # Round-trip check: an instance built from a configuration should
        # hand the same configuration back from ``get_config``.

        # - The default config should be a valid configuration for this
        #   implementation.
        config = SkLearnBallTreeHashIndex.get_default_config()
        from_default = SkLearnBallTreeHashIndex.from_config(config)
        self.assertEqual(from_default.get_config(), config)

        # - Same round trip with a non-null cache element type.
        config['cache_element']['type'] = 'DataMemoryElement'
        from_cached = SkLearnBallTreeHashIndex.from_config(config)
        self.assertEqual(from_cached.get_config(), config)
Example #11
0
    def test_remove_from_index_invalid_key_single(self):
        # Removing a single hash that was never indexed should raise
        # KeyError and leave the indexed data unchanged.
        n_rows, n_bits = 1000, 256
        hash_index = SkLearnBallTreeHashIndex(random_seed=0)
        bit_rows = np.ndarray((n_rows, n_bits), bool)
        for row in range(n_rows):
            bit_rows[row] = int_to_bit_vector_large(row, n_bits)
        hash_index.build_index(bit_rows)
        # Snapshot of the post-build data for the no-removal check below.
        data_before = np.copy(hash_index.bt.data)

        # Only integers 0..999 were indexed, so 1001 is not present.
        missing = [int_to_bit_vector_large(1001, n_bits)]
        with self.assertRaises(KeyError):
            hash_index.remove_from_index(missing)
        np.testing.assert_array_equal(data_before,
                                      np.asarray(hash_index.bt.data))
Example #12
0
def main():
    """
    Build a ball-tree hash index from the keys of a hash-to-UUID store.

    Parses the command line, loads the configuration, converts every key of
    the configured key-value store into a bit vector of ``itq_bit_length``
    bits, and builds the configured ``SkLearnBallTreeHashIndex`` from them.
    """
    cli_args = cli_parser().parse_args()
    config = smqtk.utils.bin_utils.utility_main_helper(default_config,
                                                       cli_args)
    log = logging.getLogger(__name__)

    # Loading from configurations
    bit_len = int(config['itq_bit_length'])
    log.info("Loading hash2uuid KeyValue store")
    kv_store_config = config['hash2uuid_kv_store']
    kv_store_impls = smqtk.representation.get_key_value_store_impls()
    #: :type: smqtk.representation.KeyValueStore
    hash2uuid_kv_store = smqtk.utils.plugin.from_plugin_config(
        kv_store_config, kv_store_impls)
    log.info("Initializing ball tree")
    ball_tree = SkLearnBallTreeHashIndex.from_config(
        config['sklearn_balltree'])

    log.info("Computing hash-code vectors")
    hash_vectors = []
    progress = smqtk.utils.bin_utils.ProgressReporter(log.debug, 1.0)
    progress.start()
    for hash_int in hash2uuid_kv_store.keys():
        bit_vector = smqtk.utils.bit_utils.int_to_bit_vector_large(
            hash_int, bit_len)
        hash_vectors.append(bit_vector)
        progress.increment_report()
    # Final report once all keys have been converted.
    progress.report()

    log.info("Building ball tree index")
    ball_tree.build_index(hash_vectors)
Example #13
0
 def test_init_without_cache(self):
     # Constructor arguments should be stored on the instance as given,
     # and no tree should exist before an index is built.
     hash_index = SkLearnBallTreeHashIndex(
         cache_element=None,
         leaf_size=52,
         random_seed=42,
     )
     nose.tools.assert_is_none(hash_index.cache_element)
     nose.tools.assert_equal(hash_index.leaf_size, 52)
     nose.tools.assert_equal(hash_index.random_seed, 42)
     nose.tools.assert_is_none(hash_index.bt)
Example #14
0
 def test_default_configuration(self):
     # The default configuration should hold three entries: an untyped
     # cache-element config, a leaf size of 40 and no random seed.
     config = SkLearnBallTreeHashIndex.get_default_config()
     nose.tools.assert_equal(len(config), 3)
     cache_config = config['cache_element']
     nose.tools.assert_is_instance(cache_config, dict)
     nose.tools.assert_is_none(cache_config['type'])
     nose.tools.assert_equal(config['leaf_size'], 40)
     nose.tools.assert_is_none(config['random_seed'])
Example #15
0
 def test_init_without_cache(self):
     # Constructor arguments should be stored on the instance as given,
     # and no tree should exist before an index is built.
     hash_index = SkLearnBallTreeHashIndex(
         cache_element=None,
         leaf_size=52,
         random_seed=42,
     )
     self.assertIsNone(hash_index.cache_element)
     self.assertEqual(hash_index.leaf_size, 52)
     self.assertEqual(hash_index.random_seed, 42)
     self.assertIsNone(hash_index.bt)
Example #16
0
    def test_remove_from_index_invalid_key_multiple(self):
        # A removal request mixing an indexed hash with an un-indexed one
        # should raise KeyError and leave the indexed data unchanged.
        n_rows, n_bits = 1000, 256
        hash_index = SkLearnBallTreeHashIndex(random_seed=0)
        bit_rows = np.ndarray((n_rows, n_bits), bool)
        for row in range(n_rows):
            bit_rows[row] = int_to_bit_vector_large(row, n_bits)
        hash_index.build_index(bit_rows)
        # Snapshot of the post-build data for the no-removal check below.
        data_before = np.copy(hash_index.bt.data)

        # Only integers 0..999 were indexed: 42 is present, 1008 is not.
        mixed_keys = [
            int_to_bit_vector_large(42, n_bits),
            int_to_bit_vector_large(1008, n_bits),
        ]
        with self.assertRaises(KeyError):
            hash_index.remove_from_index(mixed_keys)
        np.testing.assert_array_equal(data_before,
                                      np.asarray(hash_index.bt.data))
Example #17
0
 def test_default_configuration(self):
     # The default configuration should hold three entries: an untyped
     # cache-element config, a leaf size of 40 and no random seed.
     config = SkLearnBallTreeHashIndex.get_default_config()
     self.assertEqual(len(config), 3)
     cache_config = config['cache_element']
     self.assertIsInstance(cache_config, dict)
     self.assertIsNone(cache_config['type'])
     self.assertEqual(config['leaf_size'], 40)
     self.assertIsNone(config['random_seed'])
Example #18
0
def main():
    """
    Build a ball-tree hash index from the keys of a hash-to-UUID store.

    Parses the command line, loads the configuration, converts every key of
    the configured key-value store into a bit vector of ``itq_bit_length``
    bits, and builds the configured ``SkLearnBallTreeHashIndex`` from them.
    """
    cli_args = cli_parser().parse_args()
    config = smqtk.utils.bin_utils.utility_main_helper(default_config,
                                                       cli_args)
    log = logging.getLogger(__name__)

    # Loading from configurations
    bit_len = int(config['itq_bit_length'])
    log.info("Loading hash2uuid KeyValue store")
    kv_store_config = config['hash2uuid_kv_store']
    kv_store_impls = smqtk.representation.get_key_value_store_impls()
    #: :type: smqtk.representation.KeyValueStore
    hash2uuid_kv_store = smqtk.utils.plugin.from_plugin_config(
        kv_store_config, kv_store_impls)
    log.info("Initializing ball tree")
    ball_tree = SkLearnBallTreeHashIndex.from_config(
        config['sklearn_balltree'])

    log.info("Computing hash-code vectors")
    hash_vectors = []
    progress = smqtk.utils.bin_utils.ProgressReporter(log.debug, 1.0)
    progress.start()
    for hash_int in hash2uuid_kv_store.keys():
        bit_vector = smqtk.utils.bit_utils.int_to_bit_vector_large(
            hash_int, bit_len)
        hash_vectors.append(bit_vector)
        progress.increment_report()
    # Final report once all keys have been converted.
    progress.report()

    log.info("Building ball tree index")
    ball_tree.build_index(hash_vectors)
Example #19
0
 def test_invalid_build(self):
     # Building from an empty collection of hashes should raise
     # ValueError.
     hash_index = SkLearnBallTreeHashIndex()
     no_hashes = []
     with nose.tools.assert_raises(ValueError):
         hash_index.build_index(no_hashes)
Example #20
0
    def test_nn_no_index(self):
        # Querying before any index has been built should raise a
        # ValueError carrying the "no index" message.
        unbuilt_index = SkLearnBallTreeHashIndex()
        query = [0, 0, 0]

        with self.assertRaisesRegexp(ValueError,
                                     "No index currently set to query from"):
            unbuilt_index.nn(query)
Example #21
0
def default_config():
    """
    Assemble the default configuration dictionary for this utility.

    :return: Dictionary with the key-value store plugin configuration, the
        default ``SkLearnBallTreeHashIndex`` configuration and a default
        hash bit length of 256.
    :rtype: dict
    """
    kv_store_config = smqtk.utils.plugin.make_config(
        smqtk.representation.get_key_value_store_impls()
    )
    balltree_config = SkLearnBallTreeHashIndex.get_default_config()
    return {
        "hash2uuid_kv_store": kv_store_config,
        "sklearn_balltree": balltree_config,
        "itq_bit_length": 256,
    }
Example #22
0
 def test_save_model_with_readonly_cache(self):
     # Building with a read-only cache element should raise ValueError.
     n_rows, n_bits = 1000, 256
     readonly_cache = DataMemoryElement(readonly=True)
     hash_index = SkLearnBallTreeHashIndex(readonly_cache)
     random_bits = np.random.randint(0, 2, n_rows * n_bits)
     bit_matrix = random_bits.reshape(n_rows, n_bits)
     with self.assertRaises(ValueError):
         hash_index._build_bt_internal(bit_matrix)
Example #23
0
    def test_nn_no_index(self):
        # Querying before any index has been built should raise a
        # ValueError carrying the "no index" message.
        unbuilt_index = SkLearnBallTreeHashIndex()
        query = [0, 0, 0]

        with nose.tools.assert_raises_regexp(
                ValueError, "No index currently set to query from"):
            unbuilt_index.nn(query)
Example #24
0
 def test_save_model_with_readonly_cache(self):
     # Building with a read-only cache element should raise ValueError.
     n_rows, n_bits = 1000, 256
     readonly_cache = DataMemoryElement(readonly=True)
     hash_index = SkLearnBallTreeHashIndex(readonly_cache)
     random_bits = numpy.random.randint(0, 2, n_rows * n_bits)
     bit_matrix = random_bits.reshape(n_rows, n_bits)
     with nose.tools.assert_raises(ValueError):
         hash_index.build_index(bit_matrix)
Example #25
0
def default_config():
    """
    Assemble the default configuration dictionary for this utility.

    :return: Dictionary with the key-value store plugin configuration, the
        default ``SkLearnBallTreeHashIndex`` configuration and a default
        hash bit length of 256.
    :rtype: dict
    """
    kv_store_config = make_default_config(
        smqtk.representation.KeyValueStore.get_impls()
    )
    balltree_config = SkLearnBallTreeHashIndex.get_default_config()
    return {
        "hash2uuid_kv_store": kv_store_config,
        "sklearn_balltree": balltree_config,
        "itq_bit_length": 256,
    }
Example #26
0
 def test_init_with_empty_cache(self):
     # Constructing with an empty cache element should store the given
     # arguments and leave the tree unbuilt.
     cache = DataMemoryElement()
     hash_index = SkLearnBallTreeHashIndex(
         cache_element=cache,
         leaf_size=52,
         random_seed=42,
     )
     self.assertEqual(hash_index.cache_element, cache)
     self.assertEqual(hash_index.leaf_size, 52)
     self.assertEqual(hash_index.random_seed, 42)
     self.assertIsNone(hash_index.bt)
Example #27
0
        def test_file_cache_type(self):
            # Requires a .npz file: any other extension should be rejected
            # at construction time.
            with nose.tools.assert_raises(ValueError):
                SkLearnBallTreeHashIndex(file_cache='some_file.txt')

            # A .npz path should construct without error.
            SkLearnBallTreeHashIndex(file_cache='some_file.npz')
Example #28
0
def default_config():
    """
    Assemble the default configuration dictionary for this utility.

    :return: Dictionary with the key-value store plugin configuration, the
        default ``SkLearnBallTreeHashIndex`` configuration and a default
        hash bit length of 256.
    :rtype: dict
    """
    kv_store_config = smqtk.utils.plugin.make_config(
        smqtk.representation.get_key_value_store_impls()
    )
    balltree_config = SkLearnBallTreeHashIndex.get_default_config()
    return {
        "hash2uuid_kv_store": kv_store_config,
        "sklearn_balltree": balltree_config,
        "itq_bit_length": 256,
    }
Example #29
0
 def test_remove_from_index_no_index(self):
     # Removal should raise a KeyError when no ball-tree index has been
     # built yet.
     hash_index = SkLearnBallTreeHashIndex(random_seed=0)
     rm_hash = np.random.randint(0, 2, 256)
     # The error message is matched against the first bit of the hash.
     expected_pattern = str(rm_hash[0])
     with self.assertRaisesRegexp(KeyError, expected_pattern):
         hash_index.remove_from_index([rm_hash])
Example #30
0
def main():
    """
    Build and cache a ball-tree hash index from a pickled hash2uuids table.

    Reads ``hash2uuids_fp``, ``bit_len``, ``leaf_size``, ``rand_seed`` and
    ``balltree_model_fp`` from the parsed command line, loads the pickled
    table, converts each of its keys into a bit vector of ``bit_len`` bits
    and builds a ``SkLearnBallTreeHashIndex`` constructed with the model
    file path, leaf size and random seed.

    :raises AssertionError: If the hash2uuids file does not exist or the
        directory of the model output path does not exist.
    """
    args = cli_parser().parse_args()

    initialize_logging(logging.getLogger('smqtk'), logging.DEBUG)
    initialize_logging(logging.getLogger('__main__'), logging.DEBUG)
    log = logging.getLogger(__name__)

    hash2uuids_fp = os.path.abspath(args.hash2uuids_fp)
    bit_len = args.bit_len
    leaf_size = args.leaf_size
    rand_seed = args.rand_seed
    balltree_model_fp = os.path.abspath(args.balltree_model_fp)

    # NOTE(review): these checks are skipped when Python runs with ``-O``;
    # kept as asserts so the raised exception type stays the same.
    assert os.path.isfile(hash2uuids_fp), "Bad path: '%s'" % hash2uuids_fp
    assert os.path.isdir(os.path.dirname(balltree_model_fp)), \
        "Bad path: %s" % balltree_model_fp

    log.debug("hash2uuids_fp    : %s", hash2uuids_fp)
    log.debug("bit_len          : %d", bit_len)
    log.debug("leaf_size        : %d", leaf_size)
    log.debug("rand_seed        : %d", rand_seed)
    log.debug("balltree_model_fp: %s", balltree_model_fp)

    log.info("Loading hash2uuids table")
    # Pickle data is binary, so open the file in binary mode; text mode can
    # corrupt it on platforms that translate line endings.
    # NOTE(review): unpickling can execute arbitrary code -- only load
    # files from a trusted source.
    with open(hash2uuids_fp, 'rb') as f:
        hash2uuids = cPickle.load(f)

    log.info("Computing hash-code vectors")
    hash_vectors = []
    # Mutable state list handed to ``report_progress`` on every iteration
    # (presumably its running counters -- confirm against its definition).
    rs = [0] * 7
    for h in hash2uuids:
        hash_vectors.append(int_to_bit_vector_large(h, bit_len))
        report_progress(log.debug, rs, 1.)

    log.info("Initializing ball tree")
    btree = SkLearnBallTreeHashIndex(balltree_model_fp, leaf_size, rand_seed)

    log.info("Building ball tree")
    btree.build_index(hash_vectors)
Example #31
0
    def test_remove_from_index_last_element_with_cache(self):
        """
        Test that removing the final element also clears the cache element.
        """
        n_bits = 256
        cache = DataMemoryElement()
        hash_index = SkLearnBallTreeHashIndex(cache_element=cache,
                                              random_seed=0)
        # Index holding exactly one hash: the bit vector for integer 1.
        single_row = np.ndarray((1, n_bits), bool)
        single_row[0] = int_to_bit_vector_large(1, n_bits)

        hash_index.build_index(single_row)
        self.assertEqual(hash_index.count(), 1)
        self.assertFalse(cache.is_empty())

        # Removing that one hash should empty both index and cache.
        hash_index.remove_from_index(single_row)
        self.assertEqual(hash_index.count(), 0)
        self.assertTrue(cache.is_empty())
Example #32
0
        def test_model_reload(self):
            # A second index constructed on the same cache file should
            # answer a query identically to the index that was built.
            fd, fp = tempfile.mkstemp('.npz')
            os.close(fd)
            os.remove(fp)  # shouldn't exist before construction
            try:
                bt = SkLearnBallTreeHashIndex(fp)
                m = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
                bt.build_index(m)
                q = numpy.random.randint(0, 2, 256).astype(bool)
                bt_neighbors, bt_dists = bt.nn(q, 10)

                bt2 = SkLearnBallTreeHashIndex(fp)
                bt2_neighbors, bt2_dists = bt2.nn(q, 10)

                nose.tools.assert_is_not(bt, bt2)
                nose.tools.assert_is_not(bt.bt, bt2.bt)
                numpy.testing.assert_equal(bt2_neighbors, bt_neighbors)
                numpy.testing.assert_equal(bt2_dists, bt_dists)
            finally:
                # The file was deleted above and only exists again if the
                # build wrote it. An unconditional remove would raise
                # OSError here and mask the original test failure.
                if os.path.isfile(fp):
                    os.remove(fp)
Example #33
0
import os
import tempfile
import unittest

import mock
import nose.tools
import numpy

from smqtk.algorithms.nn_index.hash_index.sklearn_balltree import \
    SkLearnBallTreeHashIndex


__author__ = "*****@*****.**"


if SkLearnBallTreeHashIndex.is_usable():

    class TestBallTreeHashIndex (unittest.TestCase):

        def test_file_cache_type(self):
            # Requires a .npz file: any other extension should be rejected
            # at construction time.
            with nose.tools.assert_raises(ValueError):
                SkLearnBallTreeHashIndex(file_cache='some_file.txt')

            # A .npz path should construct without error.
            SkLearnBallTreeHashIndex(file_cache='some_file.npz')

        @mock.patch('smqtk.algorithms.nn_index.hash_index.sklearn_balltree.numpy.savez')
        def test_save_model_no_cache(self, m_savez):
Example #34
0
 def test_save_model_no_cache(self, m_savez):
     # Without a cache configured, building the index is expected to leave
     # the patched save function untouched. ``m_savez`` is the mock
     # injected by a patch decorator that is not visible here.
     n_rows, n_bits = 1000, 256
     hash_index = SkLearnBallTreeHashIndex()
     random_bits = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_index.build_index(random_bits.reshape(n_rows, n_bits))
     nose.tools.assert_false(m_savez.called)
Example #35
0
 def test_save_model_with_cache(self, m_savez):
     # With a cache file path given, building the index is expected to
     # invoke the patched save function exactly once. ``m_savez`` is the
     # mock injected by a patch decorator that is not visible here.
     n_rows, n_bits = 1000, 256
     hash_index = SkLearnBallTreeHashIndex('some_file.npz')
     random_bits = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_index.build_index(random_bits.reshape(n_rows, n_bits))
     nose.tools.assert_true(m_savez.called)
     nose.tools.assert_equal(m_savez.call_count, 1)