Пример #1
0
    def test_count_nonempty(self):
        """Check ``count()`` returns 234 after building from 234 vectors."""
        n_vectors, n_bits = 234, 256
        index = SkLearnBallTreeHashIndex()
        # 234 random bit vectors, each 256 bits long.
        flat_bits = np.random.randint(0, 2, n_vectors * n_bits)
        hash_matrix = flat_bits.reshape(n_vectors, n_bits)
        index.build_index(hash_matrix)

        self.assertEqual(index.count(), 234)
Пример #2
0
 def test_save_model_with_cache(self, m_savez):
     """
     With a cache element supplied, building the index should call the
     serialization hook exactly once.

     :param m_savez: Call-recording stand-in for the serialization function
         (presumably injected by a patch decorator outside this view).
     """
     n_rows, n_bits = 1000, 256
     index = SkLearnBallTreeHashIndex(DataMemoryElement(), random_seed=0)
     hash_matrix = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_matrix = hash_matrix.reshape(n_rows, n_bits)
     index.build_index(hash_matrix)
     # Serialization must have happened, and only once.
     nose.tools.assert_true(m_savez.called)
     nose.tools.assert_equal(m_savez.call_count, 1)
Пример #3
0
 def test_save_model_no_cache(self, m_savez):
     """
     Without a cache element, building the index should never call the
     serialization hook.

     :param m_savez: Call-recording stand-in for the serialization function
         (presumably injected by a patch decorator outside this view).
     """
     n_rows, n_bits = 1000, 256
     index = SkLearnBallTreeHashIndex()
     hash_matrix = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_matrix = hash_matrix.reshape(n_rows, n_bits)
     index.build_index(hash_matrix)
     # No cache element was given, so nothing should have been serialized.
     nose.tools.assert_false(m_savez.called)
Пример #4
0
    def test_build_index(self):
        """Check that building stores the input rows in the ball tree."""
        n_rows, n_bits = 1000, 256
        index = SkLearnBallTreeHashIndex(random_seed=0)
        # 1000 random bit vectors, each 256 bits long.
        flat_bits = np.random.randint(0, 2, n_rows * n_bits)
        hash_matrix = flat_bits.reshape(n_rows, n_bits)
        index.build_index(hash_matrix)

        self.assertIsNotNone(index.bt)
        # Both sides are sorted before comparing so the check does not depend
        # on the row order held inside the tree.
        stored_rows = sorted(np.array(index.bt.data).tolist())
        source_rows = sorted(hash_matrix.tolist())
        np.testing.assert_array_almost_equal(stored_rows, source_rows)
Пример #5
0
    def test_remove_from_index_invalid_key_single(self):
        """
        Removing one hash that was never indexed should raise ``KeyError``
        and leave the tree data untouched.
        """
        n_hashes, n_bits = 1000, 256
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        hashes = np.ndarray((n_hashes, n_bits), bool)
        for value in range(n_hashes):
            hashes[value] = int_to_bit_vector_large(value, n_bits)
        bt.build_index(hashes)
        # Snapshot of the tree data, compared after the failed removal.
        data_before = np.copy(bt.bt.data)

        # Only integers 0..999 were indexed, so 1001 is an unknown key.
        unknown_keys = [
            int_to_bit_vector_large(1001, n_bits),
        ]
        with self.assertRaises(KeyError):
            bt.remove_from_index(unknown_keys)
        np.testing.assert_array_equal(data_before, np.asarray(bt.bt.data))
Пример #6
0
    def test_remove_from_index_last_element_with_cache(self):
        """
        Removing the only indexed hash should empty the index and leave the
        cache element empty as well.
        """
        cache = DataMemoryElement()
        bt = SkLearnBallTreeHashIndex(cache_element=cache, random_seed=0)
        single_hash = np.ndarray((1, 256), bool)
        single_hash[0] = int_to_bit_vector_large(1, 256)

        # After the build: one entry indexed, cache holds content.
        bt.build_index(single_hash)
        self.assertEqual(bt.count(), 1)
        self.assertFalse(cache.is_empty())

        # After removing that entry: nothing indexed, cache is empty again.
        bt.remove_from_index(single_hash)
        self.assertEqual(bt.count(), 0)
        self.assertTrue(cache.is_empty())
Пример #7
0
    def test_remove_from_index_invalid_key_multiple(self):
        """
        A removal request mixing an indexed hash with an unknown one should
        raise ``KeyError`` and leave the tree data untouched.
        """
        n_hashes, n_bits = 1000, 256
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        hashes = np.ndarray((n_hashes, n_bits), bool)
        for value in range(n_hashes):
            hashes[value] = int_to_bit_vector_large(value, n_bits)
        bt.build_index(hashes)
        # Snapshot of the tree data, compared after the failed removal.
        data_before = np.copy(bt.bt.data)

        # 42 was indexed above (0..999); 1008 was not.
        mixed_keys = [
            int_to_bit_vector_large(42, n_bits),
            int_to_bit_vector_large(1008, n_bits),
        ]
        with self.assertRaises(KeyError):
            bt.remove_from_index(mixed_keys)
        np.testing.assert_array_equal(data_before, np.asarray(bt.bt.data))
Пример #8
0
        def test_model_reload(self):
            """
            Build an index constructed on a file path, then check that a
            second instance constructed on the same path returns identical
            ``nn`` results while being a distinct object with a distinct tree.
            """
            # mkstemp is used only to obtain a unique path; the file itself is
            # deleted so nothing exists there when the first index is created.
            fd, fp = tempfile.mkstemp('.npz')
            os.close(fd)
            os.remove(fp)  # shouldn't exist before construction
            try:
                bt = SkLearnBallTreeHashIndex(fp)
                m = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
                bt.build_index(m)
                q = numpy.random.randint(0, 2, 256).astype(bool)
                bt_neighbors, bt_dists = bt.nn(q, 10)

                # Second instance on the same path; queried without building.
                bt2 = SkLearnBallTreeHashIndex(fp)
                bt2_neighbors, bt2_dists = bt2.nn(q, 10)

                nose.tools.assert_is_not(bt, bt2)
                nose.tools.assert_is_not(bt.bt, bt2.bt)
                numpy.testing.assert_equal(bt2_neighbors, bt_neighbors)
                numpy.testing.assert_equal(bt2_dists, bt_dists)
            finally:
                # Guarded cleanup: if nothing was written to the path, an
                # unconditional remove would raise OSError here and hide the
                # actual failure raised inside the try body.
                if os.path.isfile(fp):
                    os.remove(fp)
Пример #9
0
    def test_load_model(self):
        """
        Two indexes sharing one cache element: the first builds the model,
        the second is only constructed. They must be distinct objects with
        distinct trees, yet give equal ``nn`` results for the same query.
        """
        n_rows, n_bits, n_neighbors = 1000, 256, 10
        shared_cache = DataMemoryElement()

        builder = SkLearnBallTreeHashIndex(shared_cache, random_seed=0)
        hash_matrix = np.random.randint(0, 2, n_rows * n_bits)
        hash_matrix = hash_matrix.reshape(n_rows, n_bits)
        builder.build_index(hash_matrix)

        # Constructed on the populated cache; never built directly.
        loader = SkLearnBallTreeHashIndex(shared_cache)
        self.assertIsNotNone(loader.bt)

        query = np.random.randint(0, 2, n_bits).astype(bool)
        expected_neighbors, expected_dists = builder.nn(query, n_neighbors)
        loaded_neighbors, loaded_dists = loader.nn(query, n_neighbors)

        self.assertIsNot(builder, loader)
        self.assertIsNot(builder.bt, loader.bt)
        np.testing.assert_equal(loaded_neighbors, expected_neighbors)
        np.testing.assert_equal(loaded_dists, expected_dists)
Пример #10
0
    def test_update_index_additive(self):
        """
        Updating an already-built index should leave the tree holding the
        union of the original and the newly supplied rows.
        """
        n_bits = 256
        index = SkLearnBallTreeHashIndex(random_seed=0)
        # 1000 initial and 100 additional random bit vectors of length 256.
        initial = np.random.randint(0, 2, 1000 * n_bits)
        initial = initial.reshape(1000, n_bits).astype(bool)
        extra = np.random.randint(0, 2, 100 * n_bits)
        extra = extra.reshape(100, n_bits).astype(bool)

        # After the first build only the initial rows are present.
        index.build_index(initial)
        stored_rows = sorted(np.array(index.bt.data).tolist())
        np.testing.assert_array_almost_equal(
            stored_rows, sorted(initial.tolist()))

        # After the update the tree data equals initial + extra rows.
        index.update_index(extra)
        stored_rows = sorted(np.array(index.bt.data).tolist())
        union_rows = sorted(np.concatenate([initial, extra], axis=0).tolist())
        np.testing.assert_array_almost_equal(stored_rows, union_rows)
Пример #11
0
    def test_remove_from_index(self):
        """
        Check that ``remove_from_index`` actually drops the requested hashes
        from the ball tree data.
        """
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        index = np.ndarray((1000, 256), bool)
        for i in range(1000):
            index[i] = int_to_bit_vector_large(i, 256)
        bt.build_index(index)
        # Pre-condition: all 1000 rows are present before removal, so the
        # shape check below really measures the effect of the removal.
        self.assertEqual(np.asarray(bt.bt.data).shape, (1000, 256))

        bt.remove_from_index([
            int_to_bit_vector_large(42, 256),
            int_to_bit_vector_large(998, 256),
        ])
        # Data block should be exactly two rows shorter.
        new_data = np.asarray(bt.bt.data)
        self.assertEqual(new_data.shape, (998, 256))
        # Make sure expected arrays are missing from data block.
        new_data_set = set(tuple(r) for r in new_data.tolist())
        self.assertNotIn(tuple(int_to_bit_vector_large(42, 256)), new_data_set)
        self.assertNotIn(tuple(int_to_bit_vector_large(998, 256)),
                         new_data_set)
Пример #12
0
def main():
    """
    Script entry point: build a ball tree hash index from a pickled
    hash-to-UUIDs table.

    Reads CLI arguments (``hash2uuids_fp``, ``bit_len``, ``leaf_size``,
    ``rand_seed``, ``balltree_model_fp``), converts every key of the loaded
    table into a bit vector of ``bit_len`` bits, and builds a
    ``SkLearnBallTreeHashIndex`` constructed with the model file path, leaf
    size and random seed.

    :raises AssertionError: If the input table path is not a file, or the
        parent directory of the output model path does not exist.
    """
    args = cli_parser().parse_args()

    initialize_logging(logging.getLogger('smqtk'), logging.DEBUG)
    initialize_logging(logging.getLogger('__main__'), logging.DEBUG)
    log = logging.getLogger(__name__)

    hash2uuids_fp = os.path.abspath(args.hash2uuids_fp)
    bit_len = args.bit_len
    leaf_size = args.leaf_size
    rand_seed = args.rand_seed
    balltree_model_fp = os.path.abspath(args.balltree_model_fp)

    assert os.path.isfile(hash2uuids_fp), "Bad path: '%s'" % hash2uuids_fp
    assert os.path.isdir(os.path.dirname(balltree_model_fp)), \
        "Bad path: %s" % balltree_model_fp

    log.debug("hash2uuids_fp    : %s", hash2uuids_fp)
    log.debug("bit_len          : %d", bit_len)
    log.debug("leaf_size        : %d", leaf_size)
    log.debug("rand_seed        : %d", rand_seed)
    log.debug("balltree_model_fp: %s", balltree_model_fp)

    log.info("Loading hash2uuids table")
    # Pickle data is a byte stream: open in binary mode so the content is not
    # altered by text-mode newline translation or decoding.
    # NOTE(review): unpickling executes arbitrary code; only point this
    # script at trusted files.
    with open(hash2uuids_fp, 'rb') as f:
        hash2uuids = cPickle.load(f)

    log.info("Computing hash-code vectors")
    hash_vectors = []
    # Mutable state list handed to report_progress on every iteration.
    rs = [0] * 7
    for h in hash2uuids:
        hash_vectors.append(int_to_bit_vector_large(h, bit_len))
        report_progress(log.debug, rs, 1.)

    log.info("Initializing ball tree")
    btree = SkLearnBallTreeHashIndex(balltree_model_fp, leaf_size, rand_seed)

    log.info("Building ball tree")
    btree.build_index(hash_vectors)
Пример #13
0
    def test_remove_from_index(self):
        """
        Check that ``remove_from_index`` shrinks the tree data by the removed
        rows and that those rows are no longer present.
        """
        n_hashes, n_bits = 1000, 256
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        hashes = np.ndarray((n_hashes, n_bits), bool)
        for value in range(n_hashes):
            hashes[value] = int_to_bit_vector_large(value, n_bits)
        bt.build_index(hashes)
        # Pre-condition: every input row is present in the tree data.
        self.assertEqual(bt.bt.data.shape, (1000, 256))

        removed_a = int_to_bit_vector_large(42, n_bits)
        removed_b = int_to_bit_vector_large(998, n_bits)
        bt.remove_from_index([removed_a, removed_b])

        # Two rows removed -> two rows fewer in the data block.
        remaining = np.asarray(bt.bt.data)
        self.assertEqual(remaining.shape, (998, 256))
        # Neither removed vector may still appear among the remaining rows.
        remaining_rows = set(tuple(row) for row in remaining.tolist())
        self.assertNotIn(tuple(int_to_bit_vector_large(42, n_bits)),
                         remaining_rows)
        self.assertNotIn(tuple(int_to_bit_vector_large(998, n_bits)),
                         remaining_rows)
Пример #14
0
    def test_remove_from_index_last_element(self):
        """
        Test removing the only element, and removing all elements in
        batches, from the index.
        """
        # Case 1: a single hash is added and then removed.
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        single = np.ndarray((1, 256), bool)
        single[0] = int_to_bit_vector_large(1, 256)
        bt.build_index(single)
        self.assertEqual(bt.count(), 1)
        bt.remove_from_index(single)
        self.assertEqual(bt.count(), 0)
        self.assertIsNone(bt.bt)

        # Case 2: 1000 hashes are added, then removed 250 at a time.
        bt = SkLearnBallTreeHashIndex(random_seed=0)
        hashes = np.ndarray((1000, 256), bool)
        for value in range(1000):
            hashes[value] = int_to_bit_vector_large(value, 256)
        bt.build_index(hashes)
        # First three batches: the count drops by 250 each time and a tree
        # is still present.
        for start, expected_left in ((0, 750), (250, 500), (500, 250)):
            bt.remove_from_index(hashes[start:start + 250])
            self.assertEqual(bt.count(), expected_left)
            self.assertIsNotNone(bt.bt)
        # Final batch empties the index; the tree attribute becomes None.
        bt.remove_from_index(hashes[750:])
        self.assertEqual(bt.count(), 0)
        self.assertIsNone(bt.bt)
Пример #15
0
 def test_save_model_with_cache(self, m_savez):
     """
     With a cache file path supplied, building the index should call the
     serialization hook exactly once.

     :param m_savez: Call-recording stand-in for the serialization function
         (presumably injected by a patch decorator outside this view).
     """
     n_rows, n_bits = 1000, 256
     index = SkLearnBallTreeHashIndex('some_file.npz')
     hash_matrix = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_matrix = hash_matrix.reshape(n_rows, n_bits)
     index.build_index(hash_matrix)
     # Serialization must have happened, and only once.
     nose.tools.assert_true(m_savez.called)
     nose.tools.assert_equal(m_savez.call_count, 1)
Пример #16
0
 def test_save_model_no_cache(self, m_savez):
     """
     Without a cache target, building the index should never call the
     serialization hook.

     :param m_savez: Call-recording stand-in for the serialization function
         (presumably injected by a patch decorator outside this view).
     """
     n_rows, n_bits = 1000, 256
     index = SkLearnBallTreeHashIndex()
     hash_matrix = numpy.random.randint(0, 2, n_rows * n_bits)
     hash_matrix = hash_matrix.reshape(n_rows, n_bits)
     index.build_index(hash_matrix)
     # Nothing to save to, so the hook must not have been invoked.
     nose.tools.assert_false(m_savez.called)