コード例 #1
0
def test_bagofwords_caching(dummy_index, monkeypatch):
    """Check the cached-state flag of ``BagOfWords`` around ``preprocess``.

    The extractor must report no cached state before ``preprocess`` runs and
    cached state afterwards; a second extractor built from the same config
    must report cached state too and must be able to run ``_build_vocab``.
    """

    def stub_embeddings(*args, **kwargs):
        return Magnitude(None)

    # Swap the pretrained-embedding loader on SlowEmbedText for a stub.
    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    cfg = {
        "name": "bagofwords",
        "datamode": "trigram",
        "maxqlen": 4,
        "maxdoclen": 800,
        "usecache": True
    }

    def build_extractor():
        # Each extractor receives its own ``provide`` dict.
        return BagOfWords(cfg, provide={"index": dummy_index})

    first = build_extractor()
    bench = DummyBenchmark()

    query_ids = list(bench.qrels.keys())
    doc_ids = list(bench.qrels[query_ids[0]].keys())

    # Nothing has been preprocessed yet.
    assert not first.is_state_cached(query_ids, doc_ids)

    first.preprocess(query_ids, doc_ids, bench.topics[bench.query_type])
    assert first.is_state_cached(query_ids, doc_ids)

    # A second instance with the same config reports the same cached state.
    second = build_extractor()
    assert second.is_state_cached(query_ids, doc_ids)
    second._build_vocab(query_ids, doc_ids, bench.topics[bench.query_type])
コード例 #2
0
def test_slowembedtext_caching(dummy_index, monkeypatch):
    """Check the cached-state flag of ``SlowEmbedText`` around ``preprocess``.

    The extractor must report no cached state before ``preprocess`` runs and
    cached state afterwards; a second extractor built from the same config
    must report cached state too and must be able to run ``_build_vocab``.
    """

    def stub_embeddings(*args, **kwargs):
        # A single zero vector plus "<pad>"-only lookup tables.
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    cfg = {
        "name": "slowembedtext",
        "embeddings": "glove6b",
        "zerounk": True,
        "calcidf": True,
        "maxqlen": MAXQLEN,
        "maxdoclen": MAXDOCLEN,
        "usecache": True,
    }

    def build_extractor():
        # Each extractor receives its own ``provide`` dict.
        return SlowEmbedText(cfg, provide={"index": dummy_index})

    first = build_extractor()
    bench = DummyBenchmark()

    query_ids = list(bench.qrels.keys())
    doc_ids = list(bench.qrels[query_ids[0]].keys())

    # Nothing has been preprocessed yet.
    assert not first.is_state_cached(query_ids, doc_ids)

    first.preprocess(query_ids, doc_ids, bench.topics[bench.query_type])
    assert first.is_state_cached(query_ids, doc_ids)

    # A second instance with the same config reports the same cached state.
    second = build_extractor()
    assert second.is_state_cached(query_ids, doc_ids)
    second._build_vocab(query_ids, doc_ids, bench.topics[bench.query_type])
コード例 #3
0
def test_train_sampler(monkeypatch, tmpdir):
    """Check the batches a ``DataLoader`` builds from ``TrainTripletSampler``.

    ``EmbedText.id2vec`` is replaced with a stub returning fixed vectors, so
    every batch of size 32 must contain those vectors for the query, the
    positive document and the negative document.
    """
    bench = DummyBenchmark()
    provided = {"collection": bench.collection, "benchmark": bench}
    text_extractor = EmbedText({"tokenizer": {"keepstops": True}}, provide=provided)

    judgments = bench.qrels.copy()
    sampler = TrainTripletSampler()
    sampler.prepare(judgments, judgments, text_extractor)

    def canned_id2vec(*args, **kwargs):
        # Fresh arrays on every call.
        return {
            "query": np.array([1, 2, 3, 4]),
            "posdoc": np.array([1, 1, 1, 1]),
            "negdoc": np.array([2, 2, 2, 2]),
        }

    monkeypatch.setattr(EmbedText, "id2vec", canned_id2vec)
    loader = torch.utils.data.DataLoader(sampler, batch_size=32)

    expected_rows = (
        ("query", [1, 2, 3, 4]),
        ("posdoc", [1, 1, 1, 1]),
        ("negdoc", [2, 2, 2, 2]),
    )
    for step, batch in enumerate(loader):
        for field, _ in expected_rows:
            assert len(batch[field]) == 32
        # Spot-check the first and the 31st element of every field.
        for field, row in expected_rows:
            for position in (0, 30):
                assert np.array_equal(batch[field][position], np.array(row))

        # Stop after a handful of batches; this only shows that the loader
        # can be iterated more than once.
        if step >= 4:
            break
コード例 #4
0
def test_birch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Smoke-test training of a ``Birch`` reranker on the dummy benchmark.

    There are no assertions; the test passes when preprocessing, model
    building and one short training run complete without raising.
    """
    bench = DummyBenchmark()
    trainer_cfg = {
        "niters": 1,
        "itersize": 2,
        "batch": 2
    }
    model = Birch(
        {"trainer": trainer_cfg},
        provide={"index": dummy_index, "benchmark": bench},
    )
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()
    # Searcher scores are attached after the model has been built.
    model.searcher_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )
コード例 #5
0
def test_pred_sampler(monkeypatch, tmpdir):
    """Check the batches a ``DataLoader`` builds from ``PredSampler``.

    ``EmbedText.id2vec`` is replaced with a stub returning a fixed query and
    positive-document vector. Each batch of size 2 must carry those vectors
    and must not contain a ``negdoc`` entry.
    """
    bench = DummyBenchmark()
    text_extractor = EmbedText(
        {"tokenizer": {"keepstops": True}},
        provide={"collection": bench.collection},
    )
    run = {"301": {"LA010189-0001": 50, "LA010189-0002": 100}}
    sampler = PredSampler()
    sampler.prepare(bench.qrels, run, text_extractor)

    def canned_id2vec(*args, **kwargs):
        # Fresh arrays on every call; there is deliberately no "negdoc" key.
        return {"query": np.array([1, 2, 3, 4]), "posdoc": np.array([1, 1, 1, 1])}

    monkeypatch.setattr(EmbedText, "id2vec", canned_id2vec)
    loader = torch.utils.data.DataLoader(sampler, batch_size=2)

    expected_rows = (("query", [1, 2, 3, 4]), ("posdoc", [1, 1, 1, 1]))
    for step, batch in enumerate(loader):
        print(step, batch)
        for field, _ in expected_rows:
            assert len(batch[field]) == 2
        assert batch.get("negdoc") is None
        for field, row in expected_rows:
            for position in (0, 1):
                assert np.array_equal(batch[field][position], np.array(row))
コード例 #6
0
def test_slowembedtext_id2vec(monkeypatch):
    """Check the vectors ``SlowEmbedText.id2vec`` returns for a doc pair.

    Verifies the query/idf lengths agree, the padded lengths equal
    ``MAXQLEN``/``MAXDOCLEN``, the number of entries with a non-zero sum
    matches the whitespace token count of the source text, and that unknown
    document ids raise ``MissingDocError`` naming the query and the first
    missing document.
    """

    def stub_embeddings(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    bench = DummyBenchmark()
    cfg = {
        "name": "slowembedtext",
        "embeddings": "glove6b",
        "zerounk": True,
        "calcidf": True,
        "maxqlen": MAXQLEN,
        "maxdoclen": MAXDOCLEN,
        "usecache": False,
    }
    extractor = SlowEmbedText(
        cfg,
        provide={"collection": DummyCollection(), "benchmark": bench},
    )

    query_ids = list(bench.qrels.keys())
    qid = query_ids[0]
    doc_ids = list(bench.qrels[qid].keys())

    extractor.preprocess(query_ids, doc_ids, bench.topics[bench.query_type])

    first_doc = doc_ids[0]
    second_doc = doc_ids[1]
    vectors = extractor.id2vec(qid, first_doc, second_doc, label=[1, 0])
    query_vec = vectors["query"]
    pos_vec = vectors["posdoc"]
    neg_vec = vectors["negdoc"]
    idf_vec = vectors["idfs"]

    assert query_vec.shape[0] == idf_vec.shape[0]

    topics = bench.topics[bench.query_type]

    assert len(query_vec) == MAXQLEN
    assert len(pos_vec) == MAXDOCLEN
    assert len(neg_vec) == MAXDOCLEN

    def count_nonzero_entries(rows):
        # Number of entries whose components do not sum to zero.
        return sum(1 for row in rows if row.sum() != 0)

    def token_count(text, limit):
        # Whitespace tokens of ``text``, capped at ``limit``.
        return len(text.strip().split()[:limit])

    assert count_nonzero_entries(query_vec) == token_count(topics[qid], MAXQLEN)
    assert count_nonzero_entries(pos_vec) == token_count(
        extractor.index.get_doc(first_doc), MAXDOCLEN
    )
    assert count_nonzero_entries(neg_vec) == token_count(
        extractor.index.get_doc(second_doc), MAXDOCLEN
    )

    # Unknown document ids must surface as MissingDocError.
    caught = None
    try:
        extractor.id2vec(qid, "0000000", "111111", label=[1, 0])
    except MissingDocError as err:
        caught = err

    assert caught is not None
    assert caught.related_qid == qid
    assert caught.missed_docid == "0000000"
コード例 #7
0
def test_dssm_unigram(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train a ``DSSM`` reranker briefly and check that ``dev.best`` is written.

    After one training iteration the file ``train/dev.best`` must exist under
    ``tmpdir``.
    """
    bench = DummyBenchmark()
    model_cfg = {
        "nhiddens": "56",
        "trainer": {
            "niters": 1,
            "itersize": 4,
            "batch": 2
        }
    }
    model = DSSM(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
コード例 #8
0
def test_reranker_creatable(tmpdir_as_cache, dummy_index, reranker_name):
    """Check that ``Reranker.create`` succeeds for ``reranker_name``.

    There are no assertions; the test passes when creation does not raise.
    """
    dependencies = dict(
        collection=dummy_index.collection,
        index=dummy_index,
        benchmark=DummyBenchmark(),
    )
    # The created object is not inspected; construction itself is the check.
    Reranker.create(reranker_name, provide=dependencies)
コード例 #9
0
def test_extractor_creatable(tmpdir_as_cache, dummy_index, extractor_name):
    """Check that ``Extractor.create`` succeeds for ``extractor_name``.

    There are no assertions; the test passes when creation does not raise.
    """
    bench = DummyBenchmark()
    dependencies = dict(
        index=dummy_index,
        collection=dummy_index.collection,
        benchmark=bench,
    )
    # The created object is not inspected; construction itself is the check.
    Extractor.create(extractor_name, provide=dependencies)
コード例 #10
0
def test_searcher_bm25(tmpdir_as_cache, tmpdir, dummy_index):
    """Run ``BM25`` over the dummy topics file and check the run file it writes.

    ``query_from_file`` must return the directory it was given, and the
    ``searcher`` file inside it must hold exactly the two expected run lines.
    """
    bm25 = BM25(provide={"index": dummy_index})
    topics_path = DummyBenchmark().get_topics_file()

    def target_dir():
        # Output location: <searcher cache path>/<benchmark module name>.
        return os.path.join(bm25.get_cache_path(), DummyBenchmark.module_name)

    returned_dir = bm25.query_from_file(topics_path, target_dir())
    assert returned_dir == target_dir()

    run_path = os.path.join(returned_dir, "searcher")
    with open(run_path, "r") as run_file:
        run_lines = list(run_file)

    expected_lines = [
        "301 Q0 LA010189-0001 1 0.139500 Anserini\n",
        "301 Q0 LA010189-0002 2 0.097000 Anserini\n",
    ]
    assert run_lines == expected_lines
コード例 #11
0
def test_bertmaxp_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Smoke-test ``TFBERTMaxP`` training with the cross-entropy loss.

    There are no assertions; the test passes when preprocessing, model
    building and one short training run complete without raising.
    """
    bench = DummyBenchmark({"collection": {"name": "dummy"}})

    index_cfg = {
        "name": "anserini",
        "indexstops": False,
        "stemmer": "porter",
        "collection": {
            "name": "dummy"
        }
    }
    extractor_cfg = {
        "name": "bertpassage",
        "usecache": False,
        "maxseqlen": 32,
        "numpassages": 2,
        "passagelen": 15,
        "stride": 5,
        "index": index_cfg,
    }
    trainer_cfg = {
        "name": "tensorflow",
        "batch": 4,
        "niters": 1,
        "itersize": 2,
        "lr": 0.001,
        "validatefreq": 1,
        "usecache": False,
        "tpuname": None,
        "tpuzone": None,
        "storage": None,
        "boardname": "default",
        "loss": "crossentropy",
        "eager": False,
    }
    model = TFBERTMaxP(
        {
            "pretrained": "bert-base-uncased",
            "extractor": extractor_cfg,
            "trainer": trainer_cfg,
        },
        provide=bench,
    )

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    model.extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()
    model.bm25_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainPairSampler()
    train_set.prepare(ranked, bench.qrels, model.extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, model.extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        "map",
    )
コード例 #12
0
def test_slowembedtext_creation(monkeypatch):
    """Check the vocabulary and embedding matrix ``SlowEmbedText`` builds.

    After ``preprocess`` the extractor's ``stoi`` keys must equal the expected
    vocabulary, the embedding matrix must have one row of width 8 per
    vocabulary entry, and the padding row must be all zeros.

    Returns:
        The preprocessed extractor. The return value is kept so that any
        caller reusing this function as a factory keeps working.
    """

    def fake_magnitude_embedding(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)

    index_cfg = {
        "name": "anserini",
        "indexstops": False,
        "stemmer": "porter",
        "collection": {
            "name": "dummy"
        }
    }
    index = AnseriniIndex(index_cfg)

    benchmark = DummyBenchmark()
    extractor_cfg = {
        "name": "slowembedtext",
        "embeddings": "glove6b",
        "zerounk": True,
        "calcidf": True,
        "maxqlen": MAXQLEN,
        "maxdoclen": MAXDOCLEN,
        "usecache": False,
    }
    extractor = SlowEmbedText(extractor_cfg,
                              provide={
                                  "index": index,
                                  "benchmark": benchmark
                              })

    qids = list(benchmark.qrels.keys())
    qid = qids[0]
    docids = list(benchmark.qrels[qid].keys())
    extractor.preprocess(qids, docids, benchmark.topics[benchmark.query_type])
    expected_vocabs = [
        "lessdummy", "dummy", "doc", "hello", "greetings", "world", "from",
        "outer", "space", "<pad>"
    ]

    # Only the set of known tokens is compared, not their indices.
    assert set(extractor.stoi.keys()) == set(expected_vocabs)

    assert extractor.embeddings.shape == (len(expected_vocabs), 8)
    for i in range(extractor.embeddings.shape[0]):
        if i == extractor.pad:
            # Sum absolute values: a plain signed sum would also pass for a
            # non-zero row whose components cancel out or are negative.
            assert abs(extractor.embeddings[i]).sum() < 1e-5

    return extractor
コード例 #13
0
def test_searcher_bm25_grid(tmpdir_as_cache, tmpdir, dummy_index):
    """Run ``BM25Grid`` over the dummy topics and check its output files.

    One run file per (k1, b) pair on the 0.1 .. 1.0 grid must exist in the
    output directory, together with a ``done`` marker file.
    """
    grid_searcher = BM25Grid(provide={"index": dummy_index})
    b_values = np.around(np.arange(0.1, 1 + 0.1, 0.1), 1)
    k1_values = np.around(np.arange(0.1, 1 + 0.1, 0.1), 1)
    topics_path = DummyBenchmark().get_topics_file()

    def target_dir():
        # Output location: <searcher cache path>/<benchmark module name>.
        return os.path.join(grid_searcher.get_cache_path(), DummyBenchmark.module_name)

    returned_dir = grid_searcher.query_from_file(topics_path, target_dir())
    assert returned_dir == target_dir()

    run_name = "searcher_bm25(k1={0},b={1})_default"
    for k1_value in k1_values:
        for b_value in b_values:
            run_path = os.path.join(returned_dir, run_name.format(k1_value, b_value))
            assert os.path.exists(run_path)

    done_marker = os.path.join(returned_dir, "done")
    assert os.path.exists(done_marker)
コード例 #14
0
def test_embedtext_creation():
    """Check the vocabulary and embedding matrix built by ``EmbedText.create``.

    The extractor's ``stoi`` keys must equal the expected vocabulary, the
    embedding matrix must have one row per vocabulary entry with the width of
    the reference Magnitude model, the padding row must be all zeros, and
    every other row must match the reference vector for its token.

    NOTE(review): ``MagnitudeUtils.download_model`` suggests this test fetches
    the reference embeddings; confirm whether network access is required.

    Returns:
        The extractor after ``create`` has run. The return value is kept so
        that any caller reusing this function as a factory keeps working.
    """
    extractor_cfg = {
        "_name": "embedtext",
        "index": "anserini",
        "tokenizer": "anserini",
        "embeddings": "glove6b",
        "zerounk": True,
        "calcidf": True,
        "maxqlen": MAXQLEN,
        "maxdoclen": MAXDOCLEN,
    }
    extractor = EmbedText(extractor_cfg)

    benchmark = DummyBenchmark({"_fold": "s1", "rundocsonly": False})
    collection = DummyCollection({"_name": "dummy"})

    index_cfg = {"_name": "anserini", "indexstops": False, "stemmer": "porter"}
    index = AnseriniIndex(index_cfg)
    index.modules["collection"] = collection

    tok_cfg = {"_name": "anserini", "keepstops": True, "stemmer": "none"}
    tokenizer = AnseriniTokenizer(tok_cfg)

    # Dependencies are wired in by hand through the ``modules`` mapping.
    extractor.modules["index"] = index
    extractor.modules["tokenizer"] = tokenizer

    qids = list(benchmark.qrels.keys())
    qid = qids[0]
    docids = list(benchmark.qrels[qid].keys())

    extractor.create(qids, docids, benchmark.topics[benchmark.query_type])

    expected_vocabs = [
        "lessdummy", "dummy", "doc", "hello", "greetings", "world", "from",
        "outer", "space", "<pad>"
    ]

    # Only the set of known tokens is compared, not their indices.
    assert set(extractor.stoi.keys()) == set(expected_vocabs)

    emb_path = "glove/light/glove.6B.300d"
    fullemb = Magnitude(MagnitudeUtils.download_model(emb_path))
    assert extractor.embeddings.shape == (len(expected_vocabs), fullemb.dim)

    for i in range(extractor.embeddings.shape[0]):
        if i == extractor.pad:
            # Sum absolute values so non-zero components cannot cancel out.
            assert abs(extractor.embeddings[i]).sum() < 1e-5
            continue
        s = extractor.itos[i]
        # Compare with the absolute difference: a signed sum is below the
        # threshold for any vector whose deviations cancel or are negative,
        # which would let badly wrong embeddings pass.
        assert abs(extractor.embeddings[i] - fullemb.query(s)).sum() < 1e-5
    return extractor
コード例 #15
0
def test_tk(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train a ``TK`` reranker briefly and check that ``dev.best`` is written.

    The pretrained-embedding loader of ``SlowEmbedText`` is stubbed out.
    After one training iteration the file ``train/dev.best`` must exist
    under ``tmpdir``.
    """

    def stub_embeddings(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    bench = DummyBenchmark()
    trainer_cfg = {
        "niters": 1,
        "itersize": 4,
        "batch": 2
    }
    model_cfg = {
        "gradkernels": True,
        "scoretanh": False,
        "singlefc": True,
        "projdim": 32,
        "ffdim": 100,
        "numlayers": 2,
        "numattheads": 4,
        "alpha": 0.5,
        "usemask": False,
        "usemixer": True,
        "finetune": True,
        "trainer": trainer_cfg,
    }
    model = TK(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
コード例 #16
0
def test_bagofwords_id2vec(tmpdir, dummy_index):
    """Check ``BagOfWords.id2vec`` against a hand-built vocabulary.

    The extractor state (``stoi``, ``itos``, ``idf``, ``qid2toks``,
    ``docid2toks``) is assigned directly instead of calling ``preprocess``;
    the vectors returned for one query and one document used as both the
    positive and the negative are then compared with fixed expected values.
    """
    bench = DummyBenchmark({})
    tokenizer = AnseriniTokenizer(
        {"name": "anserini", "keepstops": True, "stemmer": "none"}
    )
    cfg = {
        "name": "bagofwords",
        "datamode": "unigram",
        "maxqlen": 4,
        "maxdoclen": 800,
        "usecache": False
    }
    extractor = BagOfWords(
        cfg,
        provide={"index": dummy_index, "tokenizer": tokenizer, "benchmark": bench},
    )

    # Vocabulary: the pad token plus the two query terms.
    extractor.stoi = {extractor.pad_tok: extractor.pad}
    extractor.itos = {extractor.pad: extractor.pad_tok}
    extractor.idf = defaultdict(lambda: 0)

    extractor.qid2toks = {"301": ["dummy", "doc"]}
    for position, term in enumerate(("dummy", "doc"), start=1):
        extractor.stoi[term] = position
    for position, term in enumerate(("dummy", "doc"), start=1):
        extractor.itos[position] = term

    doc_terms = (
        "dummy", "dummy", "dummy", "hello", "world", "greetings", "from",
        "outer", "space",
    )
    # Each document gets its own list object.
    extractor.docid2toks = {
        "LA010189-0001": list(doc_terms),
        "LA010189-0002": list(doc_terms),
    }

    vec = extractor.id2vec("301", "LA010189-0001", "LA010189-0001")

    # Original author's note: stoi only knows 'dummy' and 'doc', so every
    # other word is mapped to 0.
    expected_ids = (
        ("qid", "301"),
        ("posdocid", "LA010189-0001"),
        ("negdocid", "LA010189-0001"),
    )
    for field, value in expected_ids:
        assert vec[field] == value

    expected_arrays = (
        ("query", [0, 1, 1]),
        ("posdoc", [6, 3, 0]),
        ("negdoc", [6, 3, 0]),
        ("query_idf", [0, 0, 0]),
    )
    for field, value in expected_arrays:
        assert np.array_equal(vec[field], value)
コード例 #17
0
def test_searcher_query(tmpdir_as_cache, tmpdir, dummy_index, searcher_name):
    """Check that ``Searcher.query`` honours the configured ``hits`` limit.

    The first query from the dummy topics file is sent to the searcher named
    ``searcher_name``, created with ``hits=1``. The result must contain
    exactly ``nhits`` entries, or ``nhits`` entries per inner dict when the
    searcher returns a dict of dicts. The size check is skipped for "SPL".
    """
    topics_fn = DummyBenchmark().get_topics_file()
    # Read the topics inside a ``with`` block so the file handle is closed
    # right away instead of being leaked. Each line is split on tabs and the
    # second field is used as the query text.
    with open(topics_fn) as topics_file:
        queries = [line.strip().split("\t")[1] for line in topics_file]
    query = queries[0]

    nhits = 1
    searcher = Searcher.create(searcher_name, config={"hits": nhits}, provide={"index": dummy_index})
    results = searcher.query(query)
    if searcher_name == "SPL":
        # The query is still executed for SPL; only the size check is skipped.
        return

    print(results.values())
    if isinstance(list(results.values())[0], dict):
        assert all(len(d) == nhits for d in results.values())
    else:
        assert len(results) == nhits
コード例 #18
0
def test_deeptiles_extract_segment_short_text(tmpdir, monkeypatch,
                                              dummy_index):
    """Check ``DeepTileExtractor.extract_segment`` on short inputs.

    Per the original author's note, these texts are too short for
    ``TextTilingTokenizer``, so this exercises the extractor's fallback.
    The first text must come back as a single segment, the second as two.
    """

    def stub_embeddings(*args, **kwargs):
        return Magnitude(None)

    monkeypatch.setattr(DeepTileExtractor, "_get_pretrained_emb", stub_embeddings)
    bench = DummyBenchmark()
    tiler = TextTilingTokenizer(k=6)
    cfg = {
        "name": "deeptiles",
        "passagelen": 30,
        "slicelen": 20,
        "tfchannel": True,
        "tilechannels": 3,
        "index": {
            "collection": {
                "name": "dummy"
            }
        },
    }
    extractor = DeepTileExtractor(
        cfg,
        provide={"index": dummy_index, "benchmark": bench},
    )

    one_segment_text = "But we in it shall be rememberèd We few, we happy few, we band of brothers"
    pieces = extractor.extract_segment(one_segment_text.split(" "), tiler)
    assert len(pieces) == 1
    # The expected segment keeps the original casing and punctuation.
    assert pieces == [
        "But we in it shall be rememberèd We few, we happy few, we band of brothers"
    ]

    two_segment_text = (
        "But we in it shall be rememberèd We few, we happy few, we band of brothers. For he to-day that sheds his "
        "blood with me Shall be my brother")
    pieces = extractor.extract_segment(two_segment_text.split(" "), tiler)
    assert len(pieces) == 2
    assert pieces == [
        "But we in it shall be rememberèd We few, we happy few, we band of brothers. For he to-day that",
        "sheds his blood with me Shall be my brother",
    ]
コード例 #19
0
def test_CDSSM(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Smoke-test training of a ``CDSSM`` reranker on the dummy benchmark.

    The pretrained-embedding loader of ``SlowEmbedText`` is stubbed out.
    There are no assertions; the test passes when preprocessing, model
    building and one short training run complete without raising.
    """

    def stub_embeddings(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    bench = DummyBenchmark()
    trainer_cfg = {
        "niters": 1,
        "itersize": 2,
        "batch": 1
    }
    model_cfg = {
        "nkernel": 3,
        "nfilter": 1,
        "nhiddens": 30,
        "windowsize": 3,
        "dropoutrate": 0,
        "trainer": trainer_cfg,
    }
    model = CDSSM(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()
    # Searcher scores are attached after the model has been built.
    model.searcher_scores = {"301": {"LA010189-0001": 2, "LA010189-0002": 1}}

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )
コード例 #20
0
def test_knrm_tf_ce(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train ``TFKNRM`` with a binary cross-entropy loss and check its output.

    Both embedding loaders are replaced by a stub returning a zero matrix and
    default-valued lookup tables. After one training iteration the file
    ``train/dev.best.index`` must exist under ``tmpdir``.
    """

    def fake_magnitude_embedding(*args, **kwargs):
        vectors = np.zeros((8, 32))
        # ``default_factory`` is called with no arguments, so the factories
        # must be zero-argument callables; a one-argument lambda raises
        # TypeError on the first missing-key lookup.
        stoi = defaultdict(lambda: 0)
        itos = defaultdict(lambda: "dummy")

        # NOTE(review): the other embedding stubs in this file return
        # (embeddings, index->token, token->index); confirm that the order
        # used here is the one the loader's callers expect.
        return vectors, stoi, itos

    monkeypatch.setattr(capreolus.extractor.common,
                        "load_pretrained_embeddings", fake_magnitude_embedding)
    monkeypatch.setattr(SlowEmbedText, "_load_pretrained_embeddings",
                        fake_magnitude_embedding)
    benchmark = DummyBenchmark()
    reranker = TFKNRM(
        {
            "gradkernels": True,
            "finetune": False,
            "trainer": {
                "niters": 1,
                "itersize": 4,
                "batch": 2,
                "loss": "binary_crossentropy"
            },
        },
        provide={
            "index": dummy_index,
            "benchmark": benchmark
        },
    )
    extractor = reranker.extractor
    metric = "map"

    extractor.preprocess(["301"], ["LA010189-0001", "LA010189-0002"],
                         benchmark.topics[benchmark.query_type])
    reranker.build_model()

    train_run = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_dataset = TrainPairSampler()
    train_dataset.prepare(train_run, benchmark.qrels, extractor)
    dev_dataset = PredSampler()
    dev_dataset.prepare(train_run, benchmark.qrels, extractor)
    reranker.trainer.train(reranker, train_dataset,
                           Path(tmpdir) / "train", dev_dataset,
                           Path(tmpdir) / "dev", benchmark.qrels, metric)

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best.index")
コード例 #21
0
def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train a ``KNRM`` reranker briefly and check that ``dev.best`` is written.

    ``EmbedText._load_pretrained_embeddings`` is replaced by a stub that sets
    a one-row zero embedding matrix and "<pad>"-only lookup tables on the
    extractor. After one training iteration the file ``train/dev.best`` must
    exist under ``tmpdir``.
    """

    def stub_load_embeddings(self):
        self.embeddings = np.zeros((1, 50))
        self.stoi = {"<pad>": 0}
        # Inverse mapping derived from stoi.
        self.itos = {index: token for token, index in self.stoi.items()}

    monkeypatch.setattr(
        EmbedText, "_load_pretrained_embeddings", stub_load_embeddings
    )

    bench = DummyBenchmark()
    trainer_cfg = {
        "niters": 1,
        "itersize": 4,
        "batch": 2
    }
    model_cfg = {
        "gradkernels": True,
        "scoretanh": False,
        "singlefc": True,
        "finetune": False,
        "trainer": trainer_cfg,
    }
    model = KNRM(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
コード例 #22
0
def test_bagofwords_create(monkeypatch, tmpdir, dummy_index):
    """Check the vocabulary ``BagOfWords.preprocess`` builds in unigram mode.

    ``stoi`` must map the expected ten tokens to indices 0-9, ``itos`` must
    be its inverse, and ``embeddings`` must compare equal to the same
    token-to-index mapping.
    """
    bench = DummyBenchmark({})
    cfg = {
        "name": "bagofwords",
        "datamode": "unigram",
        "maxqlen": 4,
        "maxdoclen": 800,
        "usecache": False
    }
    extractor = BagOfWords(cfg, provide={"index": dummy_index, "benchmark": bench})

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    extractor.preprocess(["301"], doc_ids, bench.topics["title"])

    def expected_mapping():
        # Built fresh for every comparison.
        return {
            "<pad>": 0,
            "dummy": 1,
            "doc": 2,
            "hello": 3,
            "world": 4,
            "greetings": 5,
            "from": 6,
            "outer": 7,
            "space": 8,
            "lessdummy": 9,
        }

    assert extractor.stoi == expected_mapping()

    inverse = {index: token for token, index in extractor.stoi.items()}
    assert extractor.itos == inverse

    assert extractor.embeddings == expected_mapping()
コード例 #23
0
def test_deeptilebars(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train a ``DeepTileBar`` reranker briefly and check that ``dev.best`` is written.

    ``DeepTileExtractor._get_pretrained_emb`` is stubbed out. After one
    training iteration the file ``train/dev.best`` must exist under
    ``tmpdir``.
    """

    def stub_embeddings(*args, **kwargs):
        return Magnitude(None)

    monkeypatch.setattr(DeepTileExtractor, "_get_pretrained_emb", stub_embeddings)

    bench = DummyBenchmark()
    trainer_cfg = {
        "niters": 1,
        "itersize": 4,
        "batch": 2
    }
    model_cfg = {
        "name": "DeepTileBar",
        "passagelen": 30,
        "numberfilter": 3,
        "lstmhiddendim": 3,
        "linearhiddendim1": 32,
        "linearhiddendim2": 16,
        "trainer": trainer_cfg,
    }
    model = DeepTileBar(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")
コード例 #24
0
def test_deeptiles_extract_segment_long_text(tmpdir, monkeypatch, dummy_index):
    """Check ``DeepTileExtractor.extract_segment`` on a longer two-topic text.

    The input joins a Shakespeare passage and a passage about Shangri-La;
    the extractor must return exactly the two expected segments.
    """

    def stub_embeddings(*args, **kwargs):
        return Magnitude(None)

    monkeypatch.setattr(DeepTileExtractor, "_get_pretrained_emb", stub_embeddings)
    bench = DummyBenchmark()
    # Original author's note: nltk.TextTilingTokenizer only works with large
    # blobs of text.
    tiler = TextTilingTokenizer(k=6)
    cfg = {
        "name": "deeptiles",
        "embeddings": "glove6b",
        "tilechannels": 3,
        "passagelen": 30,
        "slicelen": 20,
        "tfchannel": True,
    }
    extractor = DeepTileExtractor(
        cfg,
        provide={"index": dummy_index, "benchmark": bench},
    )

    two_topic_text = (
        "O that we now had here but one ten thousand of those men in England That do no work to-day. Whats he that "
        "wishes so? My cousin, Westmorland? No, my fair cousin. If we are marked to die, we are enough To do our "
        "country loss; and if to live, The fewer men, the greater share of honour. Gods will! I pray thee, wish"
        " not one man more. Shangri-La is a fictional place described in the 1933 novel Lost Horizon "
        "by British author James Hilton. Hilton describes Shangri-La as a mystical, harmonious valley, gently guided "
        "from a lamasery, enclosed in the western end of the Kunlun Mountains. Shangri-La has become synonymous with "
        "any earthly paradise, particularly a mythical Himalayan utopia – a permanently happy land, isolated from "
        "the world")
    pieces = extractor.extract_segment(two_topic_text.split(" "), tiler)
    assert len(pieces) == 2

    # The expected split does not fall on the topic change; the original
    # author attributes it to nltk.TextTilingTokenizer and calls it far from
    # perfect.
    expected_pieces = [
        "O that we now had here but one ten thousand of those men in England That do no work to-day. Whats he that wishes so? My cousin, Westmorland? No, my fair cousin. If we are marked to die, we are",
        " enough To do our country loss; and if to live, The fewer men, the greater share of honour. Gods will! I pray thee, wish not one man more. Shangri-La is a fictional place described in the 1933 novel Lost Horizon by British author James Hilton. Hilton describes Shangri-La as a mystical, harmonious valley, gently guided from a lamasery, enclosed in the western end of the Kunlun Mountains. Shangri-La has become synonymous with any earthly paradise, particularly a mythical Himalayan utopia – a permanently happy land, isolated from the world",
    ]
    assert pieces == expected_pieces
コード例 #25
0
def test_HINT(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
    """Train a ``HINT`` reranker briefly and check that ``dev.best`` is written.

    The pretrained-embedding loader of ``SlowEmbedText`` is stubbed out.
    After one training iteration the file ``train/dev.best`` must exist
    under ``tmpdir``.
    """

    def stub_embeddings(*args, **kwargs):
        return np.zeros((1, 8), dtype=np.float32), {0: "<pad>"}, {"<pad>": 0}

    monkeypatch.setattr(
        SlowEmbedText, "_load_pretrained_embeddings", stub_embeddings
    )

    bench = DummyBenchmark()
    model_cfg = {
        "spatialGRU": 2,
        "LSTMdim": 6,
        "kmax": 10,
        "trainer": {
            "niters": 1,
            "itersize": 2,
            "batch": 1
        }
    }
    model = HINT(model_cfg, provide={"index": dummy_index, "benchmark": bench})
    feature_extractor = model.extractor
    eval_metric = "map"

    doc_ids = ["LA010189-0001", "LA010189-0002"]
    feature_extractor.preprocess(["301"], doc_ids, bench.topics[bench.query_type])
    model.build_model()

    ranked = {"301": ["LA010189-0001", "LA010189-0002"]}
    train_set = TrainTripletSampler()
    train_set.prepare(ranked, bench.qrels, feature_extractor)
    dev_set = PredSampler()
    dev_set.prepare(ranked, bench.qrels, feature_extractor)

    out_root = Path(tmpdir)
    model.trainer.train(
        model,
        train_set,
        out_root / "train",
        dev_set,
        out_root / "dev",
        bench.qrels,
        eval_metric,
    )

    assert os.path.exists(Path(tmpdir) / "train" / "dev.best")