from spacy.lang.en import English
from spacy.kb import KnowledgeBase


def test_kb_to_bytes():
    # Test that the KB's to_bytes method works correctly
    nlp = English()
    kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
    kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3])
    kb_1.add_alias(alias="Russ Cochran",
                   entities=["Q2146908"],
                   probabilities=[0.8])
    kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5])
    kb_1.add_alias(alias="Randomness",
                   entities=["Q66", "Q2146908"],
                   probabilities=[0.1, 0.2])
    assert kb_1.contains_alias("Russ Cochran")
    kb_bytes = kb_1.to_bytes()
    kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    assert not kb_2.contains_alias("Russ Cochran")
    kb_2 = kb_2.from_bytes(kb_bytes)
    # check that both KBs are exactly the same
    assert kb_1.get_size_entities() == kb_2.get_size_entities()
    assert kb_1.entity_vector_length == kb_2.entity_vector_length
    assert kb_1.get_entity_strings() == kb_2.get_entity_strings()
    assert kb_1.get_vector("Q2146908") == kb_2.get_vector("Q2146908")
    assert kb_1.get_vector("Q66") == kb_2.get_vector("Q66")
    assert kb_2.contains_alias("Russ Cochran")
    assert kb_1.get_size_aliases() == kb_2.get_size_aliases()
    assert kb_1.get_alias_strings() == kb_2.get_alias_strings()
    assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(
        kb_2.get_alias_candidates("Russ Cochran"))
    assert len(kb_1.get_alias_candidates("Randomness")) == len(
        kb_2.get_alias_candidates("Randomness"))
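

# A small illustrative helper (not part of the original tests) showing what
# get_alias_candidates returns: Candidate objects that pair an alias with one
# entity and its prior probability. The attribute names (alias_, entity_,
# prior_prob) assume spaCy v3's Candidate API.
def example_inspect_candidates():
    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
    kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
    for candidate in kb.get_alias_candidates("Russ Cochran"):
        # Prints the alias, the candidate entity id, and its prior probability
        # (here the single candidate Q2146908 with prior 0.8).
        print(candidate.alias_, candidate.entity_, candidate.prior_prob)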


# The `nlp` fixture and the `make_tempdir` helper come from spaCy's test
# suite (conftest.py and the shared test utilities).
def test_kb_set_entities(nlp):
    """Test that set_entities entirely overwrites the previous set of entities."""
    v = [5, 6, 7, 8]
    v1 = [1, 1, 1, 0]
    v2 = [2, 2, 2, 3]
    kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
    kb1.set_entities(["E0"], [1], [v])
    assert kb1.get_entity_strings() == ["E0"]
    kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
    assert set(kb1.get_entity_strings()) == {"E1", "E2"}
    assert kb1.get_vector("E1") == v1
    assert kb1.get_vector("E2") == v2
    with make_tempdir() as d:
        kb1.to_disk(d / "kb")
        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
        kb2.from_disk(d / "kb")
        assert set(kb2.get_entity_strings()) == {"E1", "E2"}
        assert kb2.get_vector("E1") == v1
        assert kb2.get_vector("E2") == v2


from pathlib import Path
import random

import spacy
from spacy.kb import KnowledgeBase
from spacy.util import minibatch, compounding
from spacy.vocab import Vocab


# TRAIN_DATA and _apply_model are defined elsewhere in the original example
# script that this function is taken from.
def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
    """Create a blank model with the specified vocab, set up the pipeline and train the entity linker.
    The `vocab` should be the one used during creation of the KB."""
    vocab = Vocab().from_disk(vocab_path)
    # create blank Language class with correct vocab
    nlp = spacy.blank("en", vocab=vocab)
    nlp.vocab.vectors.name = "spacy_pretrained_vectors"
    print("Created blank 'en' model with vocab from '%s'" % vocab_path)

    # create the built-in pipeline components and add them to the pipeline
    # nlp.create_pipe works for built-ins that are registered with spaCy
    if "entity_linker" not in nlp.pipe_names:
        entity_linker = nlp.create_pipe("entity_linker")
        kb = KnowledgeBase(vocab=nlp.vocab)
        kb.load_bulk(kb_path)
        print("Loaded Knowledge Base from '%s'" % kb_path)
        entity_linker.set_kb(kb)
        nlp.add_pipe(entity_linker, last=True)
    else:
        entity_linker = nlp.get_pipe("entity_linker")
        kb = entity_linker.kb

    # make sure the annotated examples correspond to known identifiers in the knowledge base
    kb_ids = kb.get_entity_strings()
    for text, annotation in TRAIN_DATA:
        for offset, kb_id_dict in annotation["links"].items():
            new_dict = {}
            for kb_id, value in kb_id_dict.items():
                if kb_id in kb_ids:
                    new_dict[kb_id] = value
                else:
                    print("Removed", kb_id,
                          "from training because it is not in the KB.")
            annotation["links"][offset] = new_dict

    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train entity linker
        # reset and initialize the weights randomly
        optimizer = nlp.begin_training()
        for itn in range(n_iter):
            random.shuffle(TRAIN_DATA)
            losses = {}
            # batch up the examples using spaCy's minibatch
            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(
                    texts,  # batch of texts
                    annotations,  # batch of annotations
                    drop=0.2,  # dropout - make it harder to memorise data
                    losses=losses,
                    sgd=optimizer,
                )
            print(itn, "Losses", losses)

    # test the trained model
    _apply_model(nlp)

    # save model to output directory
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()
        nlp.to_disk(output_dir)
        print()
        print("Saved model to", output_dir)

        # test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        _apply_model(nlp2)
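

# A hypothetical invocation sketch (not part of the original script): the
# kb/vocab paths are placeholders and must point to the directories that were
# written when the KnowledgeBase was created.
def example_run_training():
    main(
        kb_path="output/kb",        # placeholder: serialized KnowledgeBase
        vocab_path="output/vocab",  # placeholder: the Vocab used to build the KB
        output_dir="output/nel_model",
        n_iter=50,
    )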


class DocumentIndex:
    # NOTE: the original snippet shows only this method; the class name
    # "DocumentIndex" is a placeholder added so the code parses. Assumed
    # module-level imports: spacy, from math import log,
    # from spacy.kb import KnowledgeBase, from spacy.pipeline import EntityLinker.
    def index_documents(self, documents):
        """Compute TF and IDF dictionaries for tokens and named entities over a
        {doc_id: text} mapping and store the results on the instance."""
        entities_list = []
        token_list = []
        entities_dic = {}
        tokens_dic = {}
        index_entities = {}
        index_tokens = {}
        new_token_dict = {}
        new_entity_dict = {}
        idf_token_norm = {}
        idf_entities_norm = {}
        tf_tokens_norm = {}
        tf_entities_norm = {}
        nlp = spacy.load("en_core_web_sm")
        # Set up a toy KnowledgeBase and EntityLinker. Note that neither is
        # actually used by the TF-IDF indexing below.
        kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
        entity_linker = EntityLinker(nlp.vocab)
        entity_linker.set_kb(kb)
        kb.add_entity(entity="glass insurance", freq=6, entity_vector=[0, 3, 5])
        kb.add_entity(entity="robber insurance", freq=342, entity_vector=[1, 9, -3])
        a = kb.get_entity_strings()
        tol_number_doc = len(documents)  # total number of documents
        for doc_id, value in documents.items():
            doc = nlp(value)
            entities_list = []
            token_list = []
            entities_dic = {}
            tokens_dic = {}
            counter_document = {}
            counter_entity = {}
            #print(doc)
            # Collect all entity mentions, then count them per document.
            for ent in doc.ents:
                entities_list.append(ent.text)
            for ent in doc.ents:
                count = entities_list.count(ent.text)
                counter_entity[ent.text] = count
                entities_dic[ent.text] = {doc_id: count}
            index_entities[doc_id] = entities_dic

            # Collect non-stopword, non-punctuation tokens, then count them.
            for token in doc:
                if not token.is_stop and not token.is_punct:
                    token_list.append(token.text)
            for text in token_list:
                counter_document[text] = token_list.count(text)

            # Single-word entities were also counted as plain tokens above;
            # subtract one token occurrence per entity mention so tokens and
            # entities are not double-counted.
            for ent in entities_list:
                parts = ent.split(' ')
                if len(parts) == 1 and parts[0] in counter_document:
                    if counter_document[parts[0]] - counter_entity[parts[0]] >= 1:
                        counter_document[parts[0]] -= 1
                    else:
                        del counter_document[parts[0]]

            for text in token_list:
                if text in counter_document:
                    tokens_dic[text] = {doc_id: counter_document[text]}
            index_tokens[doc_id] = tokens_dic
            print(index_tokens)

        # Merge the per-document indexes into term -> {doc_id: count} mappings.
        for doc_id in index_entities:
            for ent_text, counts in index_entities[doc_id].items():
                new_entity_dict.setdefault(ent_text, {}).update(counts)
        for doc_id in index_tokens:
            for tok_text, counts in index_tokens[doc_id].items():
                new_token_dict.setdefault(tok_text, {}).update(counts)
        # TF normalisation: entities use 1 + log(tf); tokens use the double
        # log 1 + log(1 + log(tf)), which dampens raw token counts more strongly.
        for ent, counter in new_entity_dict.items():
            for j in counter.keys():
                counter[j] = 1 + log(counter[j])
            tf_entities_norm[ent] = counter
        print(tf_entities_norm)
        for token, counter in new_token_dict.items():
            for j in counter.keys():
                counter[j] = 1 + log(1 + log(counter[j]))
            tf_tokens_norm[token] = counter

        # IDF: 1 + log(N / (1 + df)), where N is the total number of documents
        # and df is the number of documents containing the term.
        for ent, counter in new_entity_dict.items():
            tol_ctain_ent = len(counter)
            idf_entities_norm[ent] = 1.0 + log(tol_number_doc / (1.0 + tol_ctain_ent))

        for token, counter in new_token_dict.items():
            tol_ctain_token = len(counter)
            idf_token_norm[token] = 1.0 + log(tol_number_doc / (1.0 + tol_ctain_token))

        self.tf_tokens = tf_tokens_norm
        self.tf_entities = tf_entities_norm
        self.idf_tokens = idf_token_norm
        self.idf_entities = idf_entities_norm
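
    # A minimal scoring sketch (not part of the original class) showing how the
    # stored TF and IDF dictionaries could be combined into a TF-IDF score for
    # one document. The method name, the query format, and the default 0.4
    # weighting of tokens versus entities are illustrative assumptions only.
    def score_query(self, doc_id, query_tokens, query_entities, token_weight=0.4):
        token_score = sum(
            self.tf_tokens.get(t, {}).get(doc_id, 0.0) * self.idf_tokens.get(t, 0.0)
            for t in query_tokens
        )
        entity_score = sum(
            self.tf_entities.get(e, {}).get(doc_id, 0.0) * self.idf_entities.get(e, 0.0)
            for e in query_entities
        )
        return token_weight * token_score + entity_score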