Example no. 1
    def test_dataset(self,
                     vocab,
                     test_set,
                     coded_test_corpus,
                     args,
                     use_elastic=True,
                     use_EARL=False,
                     verbos=True):
        """Evaluate the policy network on *test_set* and report MRR scores.

        Configures the environment's entity/relation linkers (EARL endpoint
        or Elasticsearch indices), runs one evaluation step per test row,
        and averages the per-row mean reciprocal ranks.

        Parameters
        ----------
        vocab : vocabulary shared by the linkers.
        test_set : iterable of QA rows carrying gold ``sparql`` annotations.
        coded_test_corpus : pre-encoded questions, index-aligned with
            ``test_set``.
        args : namespace providing ``e`` and ``k`` forwarded to ``step``.
        use_elastic : when True (default) build Elastic-backed linkers.
        use_EARL : when True use the EARL candidate generator instead;
            takes precedence over ``use_elastic``.
        verbos : unused; kept for backward-compatible signature.

        Returns
        -------
        tuple
            ``(entity_mrr, relation_mrr)`` — means over the rows that
            actually contain entities / relations; ``nan`` if none do.
        """
        if use_EARL:
            # EARL serves both entity and relation candidates from a single
            # endpoint, so one generator instance is shared by both linkers.
            earlCG = EARLCG(config['EARL']['endpoint'],
                            config['EARL']['cache_path'])

            self.environment.entity_linker = EntityOrderedLinker(
                candidate_generator=earlCG, sorters=[], vocab=vocab)

            self.environment.relation_linker = RelationOrderedLinker(
                candidate_generator=earlCG, sorters=[], vocab=vocab)
        elif use_elastic:
            # Entities: bigram string similarity only.
            self.environment.entity_linker = EntityOrderedLinker(
                candidate_generator=ElasticCG(
                    self.elastic, index_name='entity_whole_match_index'),
                sorters=[
                    StringSimilaritySorter(
                        similarity.ngram.NGram(2).distance, True)
                ],
                vocab=vocab)

            # Relations: Levenshtein distance plus embedding similarity.
            self.environment.relation_linker = RelationOrderedLinker(
                candidate_generator=ElasticCG(
                    self.elastic, index_name='relation_whole_match_index'),
                sorters=[
                    StringSimilaritySorter(jellyfish.levenshtein_distance,
                                           False, True),
                    EmbeddingSimilaritySorter(self.word_vectorizer)
                ],
                vocab=vocab)

        self.agent.policy_network.eval()
        total_relation_mrr, total_entity_mrr = [], []
        for idx, qarow in enumerate(test_set):
            reward, relation_mrr, entity_mrr, loss, _ = self.step(
                coded_test_corpus[idx],
                qarow.lower_indicator,
                qarow,
                e=args.e,
                train=False,
                k=args.k)
            # Only score rows that actually contain the item type so the
            # mean is not diluted by rows where no linking was required.
            if len(qarow.sparql.relations) > 0:
                total_relation_mrr.append(relation_mrr)
            if len(qarow.sparql.entities) > 0:
                total_entity_mrr.append(entity_mrr)

        # np.mean([]) emits a RuntimeWarning and yields nan; make the empty
        # case explicit while preserving the nan result for callers.
        total_entity_mrr = (np.mean(total_entity_mrr)
                            if total_entity_mrr else float('nan'))
        total_relation_mrr = (np.mean(total_relation_mrr)
                              if total_relation_mrr else float('nan'))
        print('entity MRR', total_entity_mrr)
        print('relation MRR', total_relation_mrr)
        # train(False) is equivalent to eval(); kept for parity with the
        # original control flow.
        self.agent.policy_network.train(False)
        return total_entity_mrr, total_relation_mrr
Example no. 2
    def link(self, question, e, k, connecting_relations=False, free_relation_match=False, connecting_relation=False):
        """Run the linking policy over *question* and return the final result.

        Lazily builds default Elastic-backed entity/relation linkers if the
        environment has none, preprocesses the question, then repeatedly
        asks the agent for an action and feeds it to the environment until
        the episode is done.

        Parameters
        ----------
        question : raw question string.
        e : exploration parameter forwarded to ``select_action``.
        k : top-k cutoff forwarded to ``environment.link``.
        connecting_relations, free_relation_match, connecting_relation :
            flags forwarded unchanged to ``environment.link``.

        Returns
        -------
        The ``result`` produced by ``environment.link`` on the final step.
        """
        if self.environment.entity_linker is None:
            self.environment.entity_linker = EntityOrderedLinker(
                candidate_generator=ElasticCG(self.elastic, index_name='entity_whole_match_index'),
                sorters=[StringSimilaritySorter(similarity.ngram.NGram(2).distance, True, True)],
                vocab=self.vocab)
        if self.environment.relation_linker is None:
            self.environment.relation_linker = RelationOrderedLinker(
                candidate_generator=ElasticCG(self.elastic, index_name='relation_whole_match_index'),
                sorters=[StringSimilaritySorter(jellyfish.levenshtein_distance, False, True),
                         EmbeddingSimilaritySorter(self.word_vectorizer)],
                vocab=self.vocab)

        self.agent.policy_network.eval()
        normalized_question, normalized_question_with_numbers, lower_indicator = QARow.preprocess(question, [], False,
                                                                                                  False)
        # 0 is the fallback index for out-of-vocabulary words.
        coded_normalized_question = [self.vocab.getIndex(word, 0) for word in normalized_question]

        self.environment.init(coded_normalized_question, lower_indicator)
        self.agent.init()
        state = self.environment.state
        # NOTE: the original accumulated rewards/actions/probs/states here,
        # but never read them; the dead accumulators have been removed.
        while True:
            action, action_log_prob, action_prob, split_action = self.agent.select_action(state, e, False)
            # split_action is a probability; >= 0.5 means "split" (binary).
            state, done, result = self.environment.link(action, int(split_action >= 0.5), k,
                                                        question,
                                                        normalized_question_with_numbers,
                                                        connecting_relations,
                                                        free_relation_match,
                                                        connecting_relation)
            if done:
                break
        return result
Example no. 3
        phrases = [item for item in phrases if len(item["chunk"]) > 1]
        return phrases
    return []


# Alternative datasets, kept for quick switching between benchmarks:
# dataset = LC_QuAD(config['lc_quad']['train'], config['lc_quad']['test'], config['lc_quad']['vocab'],
#                   False, False)
# dataset = Qald_7_ml(config['qald_7_ml']['train'], config['qald_7_ml']['test'], config['qald_7_ml']['vocab'],
#                           False, False)
# NOTE(review): this dataset load runs at import time, outside the __main__
# guard — presumably intentional so other modules can import `dataset`, but
# confirm; otherwise move it under the guard below.
dataset = Qald_6_ml(config['qald_6_ml']['train'], config['qald_6_ml']['test'],
                    config['qald_6_ml']['vocab'], False, False)

if __name__ == '__main__':
    elastic = Elastic(config['elastic']['server'])
    # Entity linker: Elastic whole-match index ranked by bigram string
    # similarity.
    entity_linker = EntityOrderedLinker(
        candidate_generator=ElasticCG(elastic,
                                      index_name='entity_whole_match_index'),
        sorters=[
            StringSimilaritySorter(similarity.ngram.NGram(2).distance, True)
        ],
        vocab=dataset.vocab)

    # Relation linker: Levenshtein distance only; the embedding-based
    # sorter is disabled here.
    relation_linker = RelationOrderedLinker(
        candidate_generator=ElasticCG(elastic,
                                      index_name='relation_whole_match_index'),
        sorters=[
            StringSimilaritySorter(jellyfish.levenshtein_distance, False,
                                   True),
            #EmbeddingSimilaritySorter(dataset.word_vectorizer)
        ],
        vocab=dataset.vocab)