예제 #1
0
from capreolus.utils.common import download_file
from capreolus.utils.trec import load_qrels

from capreolus.benchmark.codesearchnet import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark
from capreolus.benchmark.codesearchnet import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark
from capreolus.collection.codesearchnet import CodeSearchNet as CodeSearchNetCollection
from capreolus.collection.covid import COVID as CovidCollection
from capreolus.benchmark.covid import COVID as CovidBenchmark

from capreolus.tests.common_fixtures import tmpdir_as_cache
from capreolus.utils.common import remove_newline
from capreolus.utils.loginit import get_logger

# Logger for this test module, obtained from the project's get_logger helper.
logger = get_logger(__name__)

# Names of every module registered under the "benchmark" type; the tests below
# are parametrized over this collection.
benchmarks = set(module_registry.get_module_names("benchmark"))


# Parametrize over a sorted list rather than the raw set: a set of strings
# iterates in a hash-dependent order that can differ between processes, which
# breaks distributed collection and shuffles test IDs from run to run.
@pytest.mark.parametrize("benchmark_name", sorted(benchmarks))
@pytest.mark.download
def test_benchmark_creatable(tmpdir_as_cache, benchmark_name):
    """Check that every registered benchmark can be created by name.

    Benchmarks that expose a ``download_if_missing`` attribute additionally have
    it called; the test passes as long as nothing raises.
    """
    benchmark = Benchmark.create(benchmark_name)
    if hasattr(benchmark, "download_if_missing"):
        benchmark.download_if_missing()


@pytest.mark.download
def test_csn_corpus_benchmark_downloadifmissing():
    for lang in ["ruby"]:
        logger.info(f"testing {lang}")
        cfg = {"name": "codesearchnet_corpus", "lang": lang}
예제 #2
0
import pytest

from capreolus import Benchmark, Task, module_registry
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache

# Names of every module registered under the "task" type; used to parametrize
# the creation test below.
tasks = set(module_registry.get_module_names("task"))


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("task_name", sorted(tasks))
def test_task_creatable(tmpdir_as_cache, dummy_index, task_name):
    """Check that every registered task can be created by name.

    The task is given the dummy index, that index's collection and the "dummy"
    benchmark as provided dependencies. The test passes if creation does not
    raise.
    """
    provide = {
        "index": dummy_index,
        "benchmark": Benchmark.create("dummy"),
        "collection": dummy_index.collection,
    }
    # Only successful construction is checked, so the instance is not kept.
    Task.create(task_name, provide=provide)
예제 #3
0
from capreolus.benchmark import DummyBenchmark
from capreolus.collection import DummyCollection
from capreolus.extractor.bagofwords import BagOfWords
from capreolus.extractor.deeptileextractor import DeepTileExtractor
from capreolus.extractor.embedtext import EmbedText
from capreolus.extractor.slowembedtext import SlowEmbedText
from capreolus.index import AnseriniIndex
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache
from capreolus.tokenizer import AnseriniTokenizer
from capreolus.utils.exceptions import MissingDocError
from capreolus.extractor.bertpassage import BertPassage

# NOTE(review): presumably the maximum query / document lengths used by the
# extractor tests in this file -- their use is not visible in this excerpt.
MAXQLEN = 8
MAXDOCLEN = 7

# Names of every module registered under the "extractor" type; used to
# parametrize the creation test below.
extractors = set(module_registry.get_module_names("extractor"))


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("extractor_name", sorted(extractors))
def test_extractor_creatable(tmpdir_as_cache, dummy_index, extractor_name):
    """Check that every registered extractor can be created by name.

    The extractor is given the dummy index, that index's collection and a
    ``DummyBenchmark`` instance as provided dependencies. The test passes if
    creation does not raise.
    """
    benchmark = DummyBenchmark()
    provide = {
        "index": dummy_index,
        "collection": dummy_index.collection,
        "benchmark": benchmark,
    }
    # Only successful construction is checked, so the instance is not kept.
    Extractor.create(extractor_name, provide=provide)


def test_embedtext_id2vec(monkeypatch):
    def fake_load_embeddings(self):
예제 #4
0
    def list_modules(self):
        """Print each module type known to ``module_registry`` and its module names.

        Every type is written on its own line, followed by one indented line per
        module name registered under that type.
        """
        registry = module_registry
        for module_type in registry.get_module_types():
            print(f"module type={module_type}")

            # Names are fetched after the type header is printed, one type at a time.
            names_for_type = registry.get_module_names(module_type)
            for module_name in names_for_type:
                print(f"       name={module_name}")
예제 #5
0
import pytest

from capreolus import module_registry
from capreolus.collection import DummyCollection
from capreolus.index import Index
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache

# Names of every module registered under the "index" type; used to parametrize
# the index-creation test below.
indexs = set(module_registry.get_module_names("index"))


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("index_name", sorted(indexs))
def test_create_index(tmpdir_as_cache, index_name):
    """Check that every registered index can be built over a ``DummyCollection``.

    The index must report that it does not exist before ``create_index`` is
    called and that it exists afterwards.
    """
    provide = {"collection": DummyCollection()}
    index = Index.create(index_name, provide=provide)

    assert not index.exists()
    index.create_index()
    assert index.exists()


def test_anserini_get_docs(tmpdir_as_cache, dummy_index):
    """Check the text returned by ``dummy_index.get_docs`` for one and for two doc ids."""
    first_id, second_id = "LA010189-0001", "LA010189-0002"
    first_text = "Dummy Dummy Dummy Hello world, greetings from outer space!"
    second_text = "Dummy LessDummy Hello world, greetings from outer space!"

    # Single-document lookup.
    single = dummy_index.get_docs([first_id])
    assert single == [first_text]

    # Two-document lookup; results are compared in the order the ids were given.
    pair = dummy_index.get_docs([first_id, second_id])
    assert pair == [first_text, second_text]

예제 #6
0
import os
import shutil

import pytest

from capreolus import Collection, constants, module_registry
from capreolus.tests.common_fixtures import tmpdir_as_cache

# Names of every module registered under the "collection" type; used to
# parametrize the tests below.
collections = set(module_registry.get_module_names("collection"))


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("collection_name", sorted(collections))
def test_collection_creatable(tmpdir_as_cache, collection_name):
    """Check that every registered collection can be created by name.

    The test passes if creation does not raise.
    """
    # Only successful construction is checked, so the instance is not kept.
    Collection.create(collection_name)


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("collection_name", sorted(collections))
@pytest.mark.download
def test_collection_downloadable(tmpdir_as_cache, collection_name):
    """Resolve a collection's document path and remove it if it is in the temp cache.

    The path returned by ``find_document_path`` is deleted only when it lies
    both under ``/tmp`` and under ``constants["CACHE_BASE_PATH"]``, and only if
    it exists on disk.
    """
    collection = Collection.create(collection_name)
    path = collection.find_document_path()

    # check for /tmp to reduce the impact of an invalid constants["CACHE_BASE_PATH"]
    inside_tmp_cache = path.startswith("/tmp") and path.startswith(
        constants["CACHE_BASE_PATH"].as_posix()
    )
    if inside_tmp_cache and os.path.exists(path):
        shutil.rmtree(path)
예제 #7
0
import os

import numpy as np
import pytest

from capreolus import module_registry
from capreolus.benchmark import DummyBenchmark
from capreolus.searcher.anserini import BM25, BM25Grid, Searcher
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache
from capreolus.utils.trec import load_trec_topics

# Searcher names excluded from the generic parametrized tests below (the reason
# for each exclusion is not stated in this file excerpt).
skip_searchers = {"bm25staticrob04yang19", "BM25Grid", "BM25Postprocess", "axiomatic"}
# Every registered "searcher" module name except the skipped ones.
searchers = set(module_registry.get_module_names("searcher")) - skip_searchers


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("searcher_name", sorted(searchers))
def test_searcher_runnable(tmpdir_as_cache, tmpdir, dummy_index, searcher_name):
    """Run each searcher on the dummy benchmark's topic file and check it finished.

    The searcher is created over the dummy index and queried via
    ``query_from_file``. The directory it returns must contain an entry named
    ``done``.
    """
    topics_fn = DummyBenchmark.topic_file
    searcher = Searcher.create(searcher_name, provide={"index": dummy_index})

    # Results are written beneath the searcher's cache path, in a subdirectory
    # named after the dummy benchmark module.
    run_dir = os.path.join(searcher.get_cache_path(), DummyBenchmark.module_name)
    output_dir = searcher.query_from_file(topics_fn, run_dir)

    assert os.path.exists(os.path.join(output_dir, "done"))


@pytest.mark.parametrize("searcher_name", searchers)
def test_searcher_query(tmpdir_as_cache, tmpdir, dummy_index, searcher_name):
    topics_fn = DummyBenchmark.topic_file
    query = list(load_trec_topics(topics_fn)["title"].values())[0]
    nhits = 1
    searcher = Searcher.create(searcher_name, config={"hits": nhits}, provide={"index": dummy_index})
    results = searcher.query(query)
    if searcher_name == "SPL":
예제 #8
0
from capreolus.reranker.DSSM import DSSM
from capreolus.reranker.HINT import HINT
from capreolus.reranker.KNRM import KNRM
from capreolus.reranker.PACRR import PACRR
from capreolus.reranker.POSITDRMM import POSITDRMM
from capreolus.reranker.CDSSM import CDSSM
from capreolus.reranker.TFBERTMaxP import TFBERTMaxP
from capreolus.reranker.TFKNRM import TFKNRM
from capreolus.reranker.TK import TK
from capreolus.sampler import TrainTripletSampler, TrainPairSampler, PredSampler
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache
from capreolus.reranker.TFVanillaBert import TFVanillaBERT
from capreolus.reranker.birch import Birch
from capreolus.reranker.parade import TFParade

# Names of every module registered under the "reranker" type; used to
# parametrize the creation test below.
rerankers = set(module_registry.get_module_names("reranker"))


# Sorted so the parametrization order is deterministic; iterating the raw set
# of strings yields a hash-dependent order that may differ between processes.
@pytest.mark.parametrize("reranker_name", sorted(rerankers))
def test_reranker_creatable(tmpdir_as_cache, dummy_index, reranker_name):
    """Check that every registered reranker can be created by name.

    The reranker is given the dummy index, that index's collection and a
    ``DummyBenchmark`` instance as provided dependencies. The test passes if
    creation does not raise.
    """
    benchmark = DummyBenchmark()
    provide = {
        "collection": dummy_index.collection,
        "index": dummy_index,
        "benchmark": benchmark,
    }
    # Only successful construction is checked, so the instance is not kept.
    Reranker.create(reranker_name, provide=provide)


def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    def fake_load_embeddings(self):