def test_model_reload(self):
    """
    Build an index backed by a file path, then construct a second index
    from the same path and check that both answer a query identically.

    The second instance must be a distinct object with a distinct ``bt``
    attribute, yet return equal neighbors and distances for the same
    query (presumably because it loads the model the first instance
    cached at ``fp`` -- confirm against SkLearnBallTreeHashIndex).
    """
    # Reserve a unique ``.npz`` path, then delete the file itself so only
    # the path name is handed to the index.
    fd, fp = tempfile.mkstemp('.npz')
    os.close(fd)
    os.remove(fp)  # shouldn't exist before construction

    try:
        bt = SkLearnBallTreeHashIndex(fp)
        # 1000 random rows of 256 values, each value 0 or 1.
        m = numpy.random.randint(0, 2, 1000 * 256).reshape(1000, 256)
        bt.build_index(m)

        # Random 256-element boolean query vector.
        q = numpy.random.randint(0, 2, 256).astype(bool)
        bt_neighbors, bt_dists = bt.nn(q, 10)

        # Second instance pointed at the same path as the first.
        bt2 = SkLearnBallTreeHashIndex(fp)
        bt2_neighbors, bt2_dists = bt2.nn(q, 10)

        nose.tools.assert_is_not(bt, bt2)
        nose.tools.assert_is_not(bt.bt, bt2.bt)
        numpy.testing.assert_equal(bt2_neighbors, bt_neighbors)
        numpy.testing.assert_equal(bt2_dists, bt_dists)
    finally:
        # The file was removed above and only exists again if a step in
        # the ``try`` body wrote it. Removing it unconditionally would
        # raise here and mask the failure that actually broke the test.
        if os.path.exists(fp):
            os.remove(fp)
def test_load_model(self):
    """
    Check that a second index built on the same cache element as an
    already-built index behaves identically.

    One instance builds the model; the other is constructed afterwards
    with the same cache element. The two must be distinct objects with
    distinct ``bt`` models, but must return equal neighbors and
    distances for the same query.
    """
    n_rows = 1000
    n_bits = 256
    n_neighbors = 10

    shared_cache = DataMemoryElement()

    # First instance: builds the model from random 0/1 rows.
    builder = SkLearnBallTreeHashIndex(shared_cache, random_seed=0)
    hash_matrix = np.random.randint(0, 2, n_rows * n_bits)
    hash_matrix = hash_matrix.reshape(n_rows, n_bits)
    builder.build_index(hash_matrix)

    # Second instance: constructed on the same cache element; it must
    # already hold a model without build_index being called on it.
    loader = SkLearnBallTreeHashIndex(shared_cache)
    self.assertIsNotNone(loader.bt)

    # Random boolean query vector, run against both instances.
    query = np.random.randint(0, 2, n_bits).astype(bool)
    expected_neighbors, expected_dists = builder.nn(query, n_neighbors)
    actual_neighbors, actual_dists = loader.nn(query, n_neighbors)

    # Distinct objects...
    self.assertIsNot(builder, loader)
    self.assertIsNot(builder.bt, loader.bt)
    # ...with equal query results.
    np.testing.assert_equal(actual_neighbors, expected_neighbors)
    np.testing.assert_equal(actual_dists, expected_dists)