def test_index_with_storage_of_one():
    """A DotProduct index over the FLOATS_1 pool returns exactly item 0.

    The pool is loaded with dimension argument 1. Querying with ``[0]``
    (k=10, search size 10) must yield a single neighbor whose id is 0.
    """
    single_pool = Pool.from_file(FLOATS_1, EVectorComponentType.Float, 1)

    index = Hnsw()
    index.build(single_pool, EDistance.DotProduct)

    found = index.get_nearest([0], 10, 10)
    assert len(found) == 1

    # Each result entry is indexed as [0] to get the neighbor id.
    nearest_id = found[0][0]
    assert nearest_id == 0
def test_index(pool_params, distance):
    """Build, save and reload an HNSW index, then canonize its artifacts.

    ``pool_params`` and ``distance`` are pytest fixtures. ``pool_params`` is
    read for the 'pool_file' and 'component_type' keys. The saved index
    file and a log of the 20 nearest neighbours of item 0 (search size 40)
    are returned as canonical files.
    """
    source_pool = Pool.from_file(
        pool_params['pool_file'], pool_params['component_type'], 10
    )

    # Small, single-threaded settings keep the result deterministic so it
    # can be compared against the canonical output.
    build_options = dict(
        level_size_decay=2,
        max_neighbors=5,
        search_neighborhood_size=30,
        batch_size=10,
        upper_level_batch_size=10,
        num_exact_candidates=10,
        num_threads=1,
    )
    index = Hnsw()
    index.build(source_pool, distance, **build_options)

    saved_index = yatest.common.test_output_path('index')
    index.save(saved_index)
    # Query the reloaded index so the canonical log reflects the on-disk form.
    index.load(saved_index, source_pool, distance)

    query_result = index.get_nearest(source_pool.get_item(0), 20, 40)

    neighbors_log = yatest.common.test_output_path('log')
    with open(neighbors_log, 'w') as out:
        for entry in query_result:
            print(entry, file=out)

    return [
        yatest.common.canonical_file(artifact, local=True)
        for artifact in (saved_index, neighbors_log)
    ]
def test_pool(pool_params):
    """Load a pool and canonize its item count and first item.

    ``pool_params`` is a pytest fixture read for the 'pool_file' and
    'component_type' keys.
    """
    loaded = Pool.from_file(
        pool_params['pool_file'], pool_params['component_type'], 10
    )
    summary_path = yatest.common.test_output_path('log')
    with open(summary_path, 'w') as sink:
        print(loaded.get_num_items(), file=sink)
        first_item = list(loaded.get_item(0))
        print(first_item, file=sink)
    canonical = yatest.common.canonical_file(summary_path, local=True)
    return [canonical]
def test_compare_load(pool_params):
    """Check that Pool.from_file and Pool.from_bytes produce identical pools.

    ``pool_params`` is a pytest fixture read for the 'pool_file' and
    'component_type' keys. Both pools are built with the same component type
    and the same size argument (10). They must agree on item count,
    ``dimension`` and every component of every item.
    """
    pool_file = pool_params['pool_file']
    component_type = pool_params['component_type']

    pool_1 = Pool.from_file(pool_file, component_type, 10)

    # Use a context manager so the handle is closed deterministically.
    # The original bare open() leaked it.
    with open(pool_file, "rb") as f:
        raw_bytes = f.read()
    # raw_bytes stays referenced for the rest of the test. This matters if
    # from_bytes keeps a view into the buffer rather than copying it (not
    # visible from here).
    pool_2 = Pool.from_bytes(raw_bytes, component_type, 10)

    assert pool_1.get_num_items() == pool_2.get_num_items()
    assert pool_1.dimension == pool_2.dimension

    for i in range(pool_1.get_num_items()):
        # Fetch each item once per row instead of once per component.
        item_1 = pool_1.get_item(i)
        item_2 = pool_2.get_item(i)
        for j in range(pool_1.dimension):
            assert item_1[j] == item_2[j], (i, j)
def test_online_hnsw_pool(pool_params):
    """Items added one by one to an OnlineHnsw read back unchanged.

    ``pool_params`` is a pytest fixture read for the 'pool_file' and
    'component_type' keys.
    """
    reference = Pool.from_file(
        pool_params['pool_file'], pool_params['component_type'], 10
    )
    incremental = OnlineHnsw(
        pool_params['component_type'], 10, EDistance.DotProduct
    )

    item_ids = range(reference.get_num_items())
    for idx in item_ids:
        incremental.add_item(reference.get_item(idx))

    assert incremental.get_num_items() == reference.get_num_items()

    for idx in range(reference.get_num_items()):
        stored = incremental.get_item(idx)
        expected = reference.get_item(idx)
        assert np.all(stored == expected)
def test_save_load(pool_params, distance):
    """Search results are unchanged after saving an index and loading it fresh.

    ``pool_params`` and ``distance`` are pytest fixtures. The 20 nearest
    neighbours of item 0 (search size 40) from the freshly built index must
    equal those from a new Hnsw loaded from the saved file.
    """
    data = Pool.from_file(
        pool_params['pool_file'], pool_params['component_type'], 10
    )

    original_index = Hnsw()
    original_index.build(
        data,
        distance,
        num_threads=1,
        num_exact_candidates=10,
        batch_size=10,
        search_neighborhood_size=30,
        max_neighbors=5,
    )
    before = original_index.get_nearest(data.get_item(0), 20, 40)

    dump_path = yatest.common.test_output_path('index')
    original_index.save(dump_path)

    restored_index = Hnsw()
    restored_index.load(dump_path, data, distance)
    after = restored_index.get_nearest(data.get_item(0), 20, 40)

    assert before == after
def test_online_hnsw_index(pool_params, distance):
    """get_nearest and get_nearest_and_add_item agree on the top 20 results.

    ``pool_params`` and ``distance`` are pytest fixtures. After the whole pool
    is inserted into an OnlineHnsw, item 0 is queried twice:
    read-only with k=20, then with get_nearest_and_add_item. The second call
    must add one item and return 50 results (matching
    search_neighborhood_size=50). Its first 20 must equal the read-only
    result.
    """
    source = Pool.from_file(
        pool_params['pool_file'], pool_params['component_type'], 10
    )
    online_index = OnlineHnsw(
        pool_params['component_type'],
        10,
        distance,
        search_neighborhood_size=50,
        max_neighbors=5,
        level_size_decay=2,
    )
    total = source.get_num_items()
    for position in range(total):
        online_index.add_item(source.get_item(position))

    read_only_hits = online_index.get_nearest(online_index.get_item(0), 20)
    # get_item(0) is fetched again instead of reused. The first call's
    # return value is not relied on to stay valid across the query
    # (unverified; kept to mirror the original call sequence).
    inserting_hits = online_index.get_nearest_and_add_item(
        online_index.get_item(0)
    )

    assert online_index.get_num_items() == source.get_num_items() + 1
    assert len(inserting_hits) == 50
    assert read_only_hits == inserting_hits[:20]
def test_index_with_empty_storage():
    """An index built from the FLOATS_0 pool returns no neighbors."""
    empty_pool = Pool.from_file(FLOATS_0, EVectorComponentType.Float, 1)
    index = Hnsw()
    index.build(empty_pool, EDistance.DotProduct)
    result = index.get_nearest([0], 10, 10)
    assert not len(result) > 0