예제 #1
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_index_with_storage_of_one():
    """Build a default index over the FLOATS_1 pool and query it once.

    The query must yield exactly one neighbor whose first field equals 0.
    """
    storage = Pool.from_file(FLOATS_1, EVectorComponentType.Float, 1)
    index = Hnsw()
    index.build(storage, EDistance.DotProduct)

    found = index.get_nearest([0], 10, 10)

    assert len(found) == 1
    # NOTE(review): the first field of a neighbor is presumably the item id;
    # confirm against the Hnsw.get_nearest contract.
    only_neighbor = found[0]
    assert only_neighbor[0] == 0
예제 #2
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_index(pool_params, distance):
    """Build, save, reload and query an index; canonize the index and the log.

    The pool is read from ``pool_params['pool_file']`` with component type
    ``pool_params['component_type']``. The index is built with explicit
    tuning options, written to a test output path, loaded back into the same
    ``Hnsw`` object, and then queried with item 0 of the pool. Each returned
    neighbor is printed on its own line of a log file.

    Returns:
        A list with the canonical index file followed by the canonical log
        file.
    """
    storage = Pool.from_file(pool_params['pool_file'],
                             pool_params['component_type'], 10)

    build_options = dict(
        level_size_decay=2,
        max_neighbors=5,
        search_neighborhood_size=30,
        batch_size=10,
        upper_level_batch_size=10,
        num_exact_candidates=10,
        num_threads=1,
    )
    index = Hnsw()
    index.build(storage, distance, **build_options)

    # Round-trip the index through disk before querying it.
    saved_index = yatest.common.test_output_path('index')
    index.save(saved_index)
    index.load(saved_index, storage, distance)

    query = storage.get_item(0)
    found = index.get_nearest(query, 20, 40)

    neighbors_log = yatest.common.test_output_path('log')
    with open(neighbors_log, 'w') as out:
        for entry in found:
            print(entry, file=out)

    canonical_index = yatest.common.canonical_file(saved_index, local=True)
    canonical_log = yatest.common.canonical_file(neighbors_log, local=True)
    return [canonical_index, canonical_log]
예제 #3
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_pool(pool_params):
    """Load a pool and canonize its item count and its first item.

    Writes two lines to a log file: the value of ``get_num_items()`` and
    item 0 converted to a list.

    Returns:
        A one-element list holding the canonical log file.
    """
    storage = Pool.from_file(pool_params['pool_file'],
                             pool_params['component_type'], 10)

    summary_path = yatest.common.test_output_path('log')
    with open(summary_path, 'w') as out:
        item_count = storage.get_num_items()
        print(item_count, file=out)
        first_item = list(storage.get_item(0))
        print(first_item, file=out)

    canonical_log = yatest.common.canonical_file(summary_path, local=True)
    return [canonical_log]
예제 #4
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_compare_load(pool_params):
    """Check that ``Pool.from_file`` and ``Pool.from_bytes`` load equal pools.

    The same file is loaded once by path and once from its raw bytes; both
    pools must report the same number of items, the same dimension, and
    equal components for every item.
    """
    pool_1 = Pool.from_file(pool_params['pool_file'],
                            pool_params['component_type'], 10)

    # Read the raw bytes inside a context manager so the file handle is
    # closed deterministically (the handle used to be left open).
    with open(pool_params['pool_file'], "rb") as f:
        array = f.read()
    pool_2 = Pool.from_bytes(array, pool_params['component_type'], 10)

    assert pool_1.get_num_items() == pool_2.get_num_items()
    assert pool_1.dimension == pool_2.dimension

    for i in range(pool_1.get_num_items()):
        # Fetch each item once per index instead of once per component.
        item_1 = pool_1.get_item(i)
        item_2 = pool_2.get_item(i)
        for j in range(pool_1.dimension):
            assert item_1[j] == item_2[j]
예제 #5
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_online_hnsw_pool(pool_params):
    """Check that an OnlineHnsw stores the items added to it unchanged.

    Every item of the pool is added to a fresh ``OnlineHnsw`` in order; the
    index must then report the same item count as the pool and return, for
    each position, an item equal to the pool's item at that position.
    """
    storage = Pool.from_file(pool_params['pool_file'],
                             pool_params['component_type'], 10)
    online_index = OnlineHnsw(pool_params['component_type'], 10,
                              EDistance.DotProduct)

    for position in range(storage.get_num_items()):
        source_item = storage.get_item(position)
        online_index.add_item(source_item)

    assert online_index.get_num_items() == storage.get_num_items()

    for position in range(storage.get_num_items()):
        # Left operand is evaluated first: index item, then pool item.
        assert np.all(
            online_index.get_item(position) == storage.get_item(position))
예제 #6
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_save_load(pool_params, distance):
    """Check that a saved and reloaded index answers a query identically.

    An index is built and queried with item 0 of the pool, then saved. A
    second, fresh ``Hnsw`` loads the saved file and is asked the same query;
    both results must compare equal.
    """
    storage = Pool.from_file(pool_params['pool_file'],
                             pool_params['component_type'], 10)

    build_options = dict(
        max_neighbors=5,
        search_neighborhood_size=30,
        batch_size=10,
        num_exact_candidates=10,
        num_threads=1,
    )
    original_index = Hnsw()
    original_index.build(storage, distance, **build_options)
    expected = original_index.get_nearest(storage.get_item(0), 20, 40)

    saved_index = yatest.common.test_output_path('index')
    original_index.save(saved_index)

    restored_index = Hnsw()
    restored_index.load(saved_index, storage, distance)
    actual = restored_index.get_nearest(storage.get_item(0), 20, 40)

    assert expected == actual
예제 #7
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_online_hnsw_index(pool_params, distance):
    """Compare ``get_nearest`` with ``get_nearest_and_add_item`` on OnlineHnsw.

    After all pool items are added, item 0 of the index is used as the query
    for both calls. The combined call must grow the index by one item and
    return 50 neighbors, the first 20 of which must equal the result of the
    plain ``get_nearest`` call made just before it.
    """
    storage = Pool.from_file(pool_params['pool_file'],
                             pool_params['component_type'], 10)
    online_index = OnlineHnsw(
        pool_params['component_type'],
        10,
        distance,
        level_size_decay=2,
        max_neighbors=5,
        search_neighborhood_size=50,
    )

    for position in range(storage.get_num_items()):
        online_index.add_item(storage.get_item(position))

    plain_neighbors = online_index.get_nearest(online_index.get_item(0), 20)
    combined_neighbors = online_index.get_nearest_and_add_item(
        online_index.get_item(0))

    # The combined call inserts the query, so the index holds one extra item.
    assert online_index.get_num_items() == storage.get_num_items() + 1
    # NOTE(review): 50 matches search_neighborhood_size above; presumably
    # that is where the result length comes from - confirm.
    assert len(combined_neighbors) == 50

    leading_neighbors = combined_neighbors[:20]
    assert plain_neighbors == leading_neighbors
예제 #8
0
파일: test.py 프로젝트: mjjohns1/catboost
def test_index_with_empty_storage():
    """Build a default index over the FLOATS_0 pool; a query must find nothing."""
    storage = Pool.from_file(FLOATS_0, EVectorComponentType.Float, 1)
    index = Hnsw()
    index.build(storage, EDistance.DotProduct)

    found = index.get_nearest([0], 10, 10)
    assert len(found) == 0