예제 #1
0
def test_kb_valid_entities(nlp):
    """Check a KB populated with three entities and two aliases.

    The test verifies the reported entity and alias counts, the vectors
    returned for each added entity, and the prior probabilities returned for
    entity/alias pairs that were added as well as for pairs that were not
    (which are expected to come back as 0.0).
    """
    # (entity id, frequency, entity vector)
    entity_rows = (
        ("Q1", 0.9, (8, 4, 3)),
        ("Q2", 0.5, (2, 1, 0)),
        ("Q3", 0.5, (-1, -6, 5)),
    )
    # (alias, candidate entity ids, prior probabilities)
    alias_rows = (
        ("douglas", ["Q2", "Q3"], [0.8, 0.2]),
        ("adam", ["Q2"], [0.9]),
    )
    # (entity id, alias, expected prior probability)
    expected_priors = (
        ("Q2", "douglas", 0.8),
        ("Q3", "douglas", 0.2),
        ("Q342", "douglas", 0.0),
        ("Q3", "douglassssss", 0.0),
    )

    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

    # populate the KB: entities first, then the aliases that refer to them
    for entity_id, frequency, vector in entity_rows:
        kb.add_entity(entity=entity_id, freq=frequency, entity_vector=list(vector))
    for alias_text, entity_ids, probs in alias_rows:
        kb.add_alias(alias=alias_text, entities=entity_ids, probabilities=probs)

    # the KB should report exactly what was added
    assert kb.get_size_entities() == 3
    assert kb.get_size_aliases() == 2

    # every entity vector should be retrievable unchanged
    for entity_id, _, vector in entity_rows:
        assert kb.get_vector(entity_id) == list(vector)

    # prior probabilities, including entity/alias pairs that were never added
    for entity_id, alias_text, prior in expected_priors:
        assert_almost_equal(
            kb.get_prior_prob(entity=entity_id, alias=alias_text), prior
        )
예제 #2
0
def test_vocab_serialization(nlp):
    """Check that KB string information survives a round trip through disk.

    A KB is built on ``nlp.vocab``, written to a temporary directory, and read
    back into a second KB constructed with a fresh ``Vocab()``. The candidate
    returned for the alias "adam" must expose the same hashes and the same
    string forms before and after the round trip; the reloaded KB must also
    return the stored vector for "Q2" and its prior probability for "douglas".
    """
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # entities; keep the value returned for Q2 to compare against candidates
    kb.add_entity(entity="Q1", freq=27, entity_vector=[1])
    q2_key = kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
    kb.add_entity(entity="Q3", freq=5, entity_vector=[3])

    # aliases; keep the value returned for "adam" for the same reason
    kb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
    adam_key = kb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

    def check_adam_candidates(knowledge_base):
        # "adam" was registered with a single entity, so exactly one candidate
        # is expected, carrying both the hash and the string representations.
        found = knowledge_base.get_alias_candidates("adam")
        assert len(found) == 1
        only = found[0]
        assert only.entity == q2_key
        assert only.entity_ == "Q2"
        assert only.alias == adam_key
        assert only.alias_ == "adam"

    # sanity check on the in-memory KB before serializing
    check_adam_candidates(kb)

    with make_tempdir() as tmp_dir:
        kb_path = tmp_dir / "kb"
        kb.to_disk(kb_path)

        # reload into a KB that starts from an empty vocab
        reloaded = KnowledgeBase(Vocab(), entity_vector_length=1)
        reloaded.from_disk(kb_path)

        check_adam_candidates(reloaded)

        assert reloaded.get_vector("Q2") == [2]
        assert_almost_equal(reloaded.get_prior_prob("Q2", "douglas"), 0.4)