from capreolus.utils.common import download_file from capreolus.utils.trec import load_qrels from capreolus.benchmark.codesearchnet import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark from capreolus.benchmark.codesearchnet import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark from capreolus.collection.codesearchnet import CodeSearchNet as CodeSearchNetCollection from capreolus.collection.covid import COVID as CovidCollection from capreolus.benchmark.covid import COVID as CovidBenchmark from capreolus.tests.common_fixtures import tmpdir_as_cache from capreolus.utils.common import remove_newline from capreolus.utils.loginit import get_logger logger = get_logger(__name__) benchmarks = set(module_registry.get_module_names("benchmark")) @pytest.mark.parametrize("benchmark_name", benchmarks) @pytest.mark.download def test_benchmark_creatable(tmpdir_as_cache, benchmark_name): benchmark = Benchmark.create(benchmark_name) if hasattr(benchmark, "download_if_missing"): benchmark.download_if_missing() @pytest.mark.download def test_csn_corpus_benchmark_downloadifmissing(): for lang in ["ruby"]: logger.info(f"testing {lang}") cfg = {"name": "codesearchnet_corpus", "lang": lang}
import pytest

from capreolus import Benchmark, Task, module_registry
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache

# Sorted so that every pytest / pytest-xdist worker collects the parametrized
# tests in the same order: iteration order of a set of strings depends on
# PYTHONHASHSEED and may differ between worker processes.
tasks = sorted(set(module_registry.get_module_names("task")))


@pytest.mark.parametrize("task_name", tasks)
def test_task_creatable(tmpdir_as_cache, dummy_index, task_name):
    """Check that every registered task module can be created.

    The task receives, via ``provide``, the ``dummy_index`` fixture, a freshly
    created ``"dummy"`` benchmark, and the dummy index's collection. The test
    passes when ``Task.create`` does not raise.
    """
    provide = {
        "index": dummy_index,
        "benchmark": Benchmark.create("dummy"),
        "collection": dummy_index.collection,
    }
    Task.create(task_name, provide=provide)
from capreolus.benchmark import DummyBenchmark from capreolus.collection import DummyCollection from capreolus.extractor.bagofwords import BagOfWords from capreolus.extractor.deeptileextractor import DeepTileExtractor from capreolus.extractor.embedtext import EmbedText from capreolus.extractor.slowembedtext import SlowEmbedText from capreolus.index import AnseriniIndex from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache from capreolus.tokenizer import AnseriniTokenizer from capreolus.utils.exceptions import MissingDocError from capreolus.extractor.bertpassage import BertPassage MAXQLEN = 8 MAXDOCLEN = 7 extractors = set(module_registry.get_module_names("extractor")) @pytest.mark.parametrize("extractor_name", extractors) def test_extractor_creatable(tmpdir_as_cache, dummy_index, extractor_name): benchmark = DummyBenchmark() provide = { "index": dummy_index, "collection": dummy_index.collection, "benchmark": benchmark } extractor = Extractor.create(extractor_name, provide=provide) def test_embedtext_id2vec(monkeypatch): def fake_load_embeddings(self):
def list_modules(self):
    """Print each registered module type, followed by the module names under it.

    A ``module type=<type>`` header line is printed before querying the names
    for that type. Each name is then printed on its own line beneath it.
    """
    registry_types = module_registry.get_module_types()
    for kind in registry_types:
        print(f"module type={kind}")
        names_for_kind = module_registry.get_module_names(kind)
        for entry in names_for_kind:
            print(f" name={entry}")
import pytest

from capreolus import module_registry
from capreolus.collection import DummyCollection
from capreolus.index import Index
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache

# Sorted so that every pytest / pytest-xdist worker collects the parametrized
# tests in the same order: iteration order of a set of strings depends on
# PYTHONHASHSEED and may differ between worker processes.
indexs = sorted(set(module_registry.get_module_names("index")))


@pytest.mark.parametrize("index_name", indexs)
def test_create_index(tmpdir_as_cache, index_name):
    """Build every registered index type over a ``DummyCollection``.

    Asserts that ``exists()`` is False before ``create_index()`` is called and
    True afterwards.
    """
    provide = {"collection": DummyCollection()}
    index = Index.create(index_name, provide=provide)
    assert not index.exists()
    index.create_index()
    assert index.exists()


def test_anserini_get_docs(tmpdir_as_cache, dummy_index):
    """Check ``get_docs`` on the dummy index for one and for two document ids.

    The expected lists pair each requested id, in request order, with its
    document text.
    """
    docs = dummy_index.get_docs(["LA010189-0001"])
    assert docs == ["Dummy Dummy Dummy Hello world, greetings from outer space!"]

    docs = dummy_index.get_docs(["LA010189-0001", "LA010189-0002"])
    assert docs == [
        "Dummy Dummy Dummy Hello world, greetings from outer space!",
        "Dummy LessDummy Hello world, greetings from outer space!",
    ]
import os
import shutil

import pytest

from capreolus import Collection, constants, module_registry
from capreolus.tests.common_fixtures import tmpdir_as_cache

# Sorted so that every pytest / pytest-xdist worker collects the parametrized
# tests in the same order: iteration order of a set of strings depends on
# PYTHONHASHSEED and may differ between worker processes.
collections = sorted(set(module_registry.get_module_names("collection")))


@pytest.mark.parametrize("collection_name", collections)
def test_collection_creatable(tmpdir_as_cache, collection_name):
    """Check that every registered collection module can be created without raising."""
    Collection.create(collection_name)


@pytest.mark.parametrize("collection_name", collections)
@pytest.mark.download
def test_collection_downloadable(tmpdir_as_cache, collection_name):
    """Resolve each collection's document path, then delete it if it is a temporary copy.

    The test carries the ``download`` marker; presumably ``find_document_path()``
    fetches the documents when they are absent (TODO confirm). The returned path
    is removed afterwards only when it starts with both ``/tmp`` and
    ``constants["CACHE_BASE_PATH"]``.
    """
    collection = Collection.create(collection_name)
    path = collection.find_document_path()

    # check for /tmp to reduce the impact of an invalid constants["CACHE_BASE_PATH"]
    if path.startswith("/tmp") and path.startswith(constants["CACHE_BASE_PATH"].as_posix()):
        if os.path.exists(path):
            shutil.rmtree(path)
import os import numpy as np import pytest from capreolus import module_registry from capreolus.benchmark import DummyBenchmark from capreolus.searcher.anserini import BM25, BM25Grid, Searcher from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache from capreolus.utils.trec import load_trec_topics skip_searchers = {"bm25staticrob04yang19", "BM25Grid", "BM25Postprocess", "axiomatic"} searchers = set(module_registry.get_module_names("searcher")) - skip_searchers @pytest.mark.parametrize("searcher_name", searchers) def test_searcher_runnable(tmpdir_as_cache, tmpdir, dummy_index, searcher_name): topics_fn = DummyBenchmark.topic_file searcher = Searcher.create(searcher_name, provide={"index": dummy_index}) output_dir = searcher.query_from_file(topics_fn, os.path.join(searcher.get_cache_path(), DummyBenchmark.module_name)) assert os.path.exists(os.path.join(output_dir, "done")) @pytest.mark.parametrize("searcher_name", searchers) def test_searcher_query(tmpdir_as_cache, tmpdir, dummy_index, searcher_name): topics_fn = DummyBenchmark.topic_file query = list(load_trec_topics(topics_fn)["title"].values())[0] nhits = 1 searcher = Searcher.create(searcher_name, config={"hits": nhits}, provide={"index": dummy_index}) results = searcher.query(query) if searcher_name == "SPL":
from capreolus.reranker.DSSM import DSSM from capreolus.reranker.HINT import HINT from capreolus.reranker.KNRM import KNRM from capreolus.reranker.PACRR import PACRR from capreolus.reranker.POSITDRMM import POSITDRMM from capreolus.reranker.CDSSM import CDSSM from capreolus.reranker.TFBERTMaxP import TFBERTMaxP from capreolus.reranker.TFKNRM import TFKNRM from capreolus.reranker.TK import TK from capreolus.sampler import TrainTripletSampler, TrainPairSampler, PredSampler from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache from capreolus.reranker.TFVanillaBert import TFVanillaBERT from capreolus.reranker.birch import Birch from capreolus.reranker.parade import TFParade rerankers = set(module_registry.get_module_names("reranker")) @pytest.mark.parametrize("reranker_name", rerankers) def test_reranker_creatable(tmpdir_as_cache, dummy_index, reranker_name): benchmark = DummyBenchmark() provide = { "collection": dummy_index.collection, "index": dummy_index, "benchmark": benchmark } reranker = Reranker.create(reranker_name, provide=provide) def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch): def fake_load_embeddings(self):