Example #1
    def _process_into_collections(self, indices, data_type):
        """
        Processing all parameters into collections.
        returns:
            NewsWordsCollection and RelationCollection
        """
        def find_feature_vector_for_opinion(opinion_vector_collections, opinion):
            assert(isinstance(opinion_vector_collections, list))

            for collection in opinion_vector_collections:
                assert(isinstance(collection, OpinionVectorCollection))
                if not collection.has_opinion(opinion):
                    continue
                return collection.find_by_opinion(opinion)

            return None

        assert(isinstance(indices, list))

        erc = ExtractedRelationsCollection()
        ntc = NewsTermsCollection()
        for news_index in indices:
            assert(isinstance(news_index, int))

            entity_filepath = self.io.get_entity_filepath(news_index)
            news_filepath = self.io.get_news_filepath(news_index)
            opin_filepath = self.io.get_opinion_input_filepath(news_index)
            neutral_filepath = self.io.get_neutral_filepath(news_index, data_type)

            news = News.from_file(news_filepath,
                                  EntityCollection.from_file(entity_filepath, self.Settings.Stemmer),
                                  stemmer=self.Settings.Stemmer)

            # The neutral collection is always read; the etalon opinions are
            # added only for the train part.
            opinions_collections = [OpinionCollection.from_file(neutral_filepath,
                                                                self.io.get_synonyms_collection_filepath(),
                                                                self.Settings.Stemmer)]
            if data_type == DataType.Train:
                opinions_collections.append(OpinionCollection.from_file(opin_filepath,
                                                                        self.io.get_synonyms_collection_filepath(),
                                                                        self.Settings.Stemmer))

            news_terms = NewsTerms.create_from_news(news_index, news, keep_tokens=self.Settings.KeepTokens)

            for relations, opinion in self._extract_relations(opinions_collections, news, news_terms):

                feature_vector = find_feature_vector_for_opinion(self.get_opinion_vector_collection(news_index, data_type),
                                                                 opinion)

                erc.add_news_relations(relations,
                                       opinion,
                                       news_terms,
                                       news_index,
                                       feature_vector)
            ntc.add_news_terms(news_terms)

        return ntc, erc
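The inner helper above is a first-match-or-None lookup across several
OpinionVectorCollection instances. A minimal self-contained sketch of the
same pattern, with a hypothetical Collection class standing in for
OpinionVectorCollection:

# Sketch only: `Collection` is a hypothetical stand-in, not an AREkit type.
class Collection(object):
    def __init__(self, vectors):
        self._vectors = vectors        # dict: opinion key -> feature vector

    def has_opinion(self, opinion):
        return opinion in self._vectors

    def find_by_opinion(self, opinion):
        return self._vectors[opinion]

def find_first(collections, opinion):
    for collection in collections:
        if not collection.has_opinion(opinion):
            continue
        return collection.find_by_opinion(opinion)
    return None                        # absent from every collection

print(find_first([Collection({"a": [0.1, 0.2]})], "a"))   # [0.1, 0.2]
print(find_first([Collection({"a": [0.1, 0.2]})], "b"))   # None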
Example #2
    def __read_collection(self, io, data_type, settings):
        assert(isinstance(io, RuSentRelNetworkIO))
        assert(isinstance(data_type, unicode))
        assert(isinstance(settings, CommonModelSettings))

        erc = ExtractedRelationsCollection()
        ntc = NewsTermsCollection()
        entities_list = []
        missed_relations_total = 0
        for news_index in io.get_data_indices(data_type):
            assert(isinstance(news_index, int))

            entity_filepath = io.get_entity_filepath(news_index)
            news_filepath = io.get_news_filepath(news_index)
            opin_filepath = io.get_etalon_doc_opins_filepath(news_index)
            neutral_filepath = io.get_neutral_filepath(news_index, data_type)

            entities = EntityCollection.from_file(entity_filepath, settings.Stemmer, self.__synonyms)

            news = News.from_file(news_filepath, entities)

            opinions_collections = [OpinionCollection.from_file(neutral_filepath, self.__synonyms)]
            if data_type == DataType.Train:
                opinions_collections.append(OpinionCollection.from_file(opin_filepath, self.__synonyms))

            news_terms = NewsTerms.create_from_news(news_index, news, keep_tokens=settings.KeepTokens)
            news_terms_helper = NewsTermsHelper(news_terms)

            if DebugKeys.NewsTermsStatisticShow:
                news_terms_helper.debug_statistics()
            if DebugKeys.NewsTermsShow:
                news_terms_helper.debug_show_terms()

            for relations, opinion, opinions in self.__extract_relations(opinions_collections, news, news_terms):
                reversed_opinion = ContextModelInitHelper.__find_or_create_reversed_opinion(opinion, opinions_collections)
                missed = erc.add_news_relations(relations=relations,
                                                label=self.__labels_helper.create_label_from_opinions(forward=opinion, backward=reversed_opinion),
                                                news_terms=news_terms,
                                                news_index=news_index,
                                                check_relation_is_correct=lambda r: Sample.check_ability_to_create_sample(
                                                    window_size=settings.TermsPerContext,
                                                    relation=r))
                missed_relations_total += missed

            ntc.add_news_terms(news_terms)
            entities_list.append(entities)

        return ntc, erc, entities_list, missed_relations_total
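The check_relation_is_correct callback above drops relations that cannot
yield a sample for the configured TermsPerContext window; the number of
dropped relations is accumulated in missed_relations_total. A hypothetical
illustration of such a window check (the real rule lives in
Sample.check_ability_to_create_sample):

# Hypothetical window check; positions are term indices in the document.
def fits_context_window(window_size, source_pos, target_pos):
    # Both entities of a relation must fall into one context window.
    return abs(source_pos - target_pos) < window_size

relations = [(3, 10), (3, 80)]     # (source_pos, target_pos) pairs
kept = [r for r in relations if fits_context_window(50, r[0], r[1])]
print(kept)                        # [(3, 10)]
print(len(relations) - len(kept))  # 1 relation missed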
Example #3
    def _process_into_collections(self, indices, entity_indices,
                                  word_embedding, window_size_in_words,
                                  is_train_collection):
        assert (isinstance(indices, list))
        assert (isinstance(word_embedding, Embedding))
        assert (isinstance(is_train_collection, bool))

        rc = ExtractedRelationsCollection()
        nwc = NewsWordsCollection(entity_indices, word_embedding)
        for n in indices:
            assert (isinstance(n, int))

            entity_filepath = self.io.get_entity_filepath(n)
            news_filepath = self.io.get_news_filepath(n)
            opin_filepath = self.io.get_opinion_input_filepath(n)
            neutral_filepath = self.io.get_neutral_filepath(
                n, is_train_collection)

            news = News.from_file(news_filepath,
                                  EntityCollection.from_file(entity_filepath))

            opinions_collections = [
                OpinionCollection.from_file(neutral_filepath,
                                            self.synonyms_filepath)
            ]
            if is_train_collection:
                opinions_collections.append(
                    OpinionCollection.from_file(opin_filepath,
                                                self.synonyms_filepath))

            news_words = NewsWords(n, news)
            news_descriptor = self.create_news_descriptor(
                n, news, news_words, opinions_collections, is_train_collection)

            rc.add_news_relations(news_descriptor, self.synonyms,
                                  window_size_in_words, is_train_collection)
            nwc.add_news(news_words)

        return nwc, rc
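All _process_into_collections variants share one branching rule: the
neutral opinions are always loaded, while the etalon (labeled) opinions are
added only for the train collection. Stripped of the I/O details, the
pattern reduces to the following sketch (`load` is a stand-in for
OpinionCollection.from_file):

# Schematic version of the shared train/test branching.
def collect_opinions(neutral_filepath, etalon_filepath, is_train, load):
    collections = [load(neutral_filepath)]      # neutral part, always read
    if is_train:
        collections.append(load(etalon_filepath))
    return collections

load = lambda filepath: "opinions<{}>".format(filepath)
print(collect_opinions("art1.neut.txt", "art1.opin.txt", True, load))
# ['opinions<art1.neut.txt>', 'opinions<art1.opin.txt>']
print(collect_opinions("art1.neut.txt", "art1.opin.txt", False, load))
# ['opinions<art1.neut.txt>']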
Example #4
from os import path

# Synonyms collection shared by the train and test passes below.
synonyms = SynonymsCollection.from_file(io_utils.get_synonyms_filepath())

#
# Train
#
root = io_utils.train_root()
for n in io_utils.train_indices():
    entity_filepath = path.join(root, "art{}.ann".format(n))
    news_filepath = path.join(root, "art{}.txt".format(n))
    opin_filepath = path.join(root, "art{}.opin.txt".format(n))
    neutral_filepath = path.join(root, "art{}.neut.txt".format(n))

    print neutral_filepath

    entities = EntityCollection.from_file(entity_filepath)
    news = News.from_file(news_filepath, entities)
    opinions = OpinionCollection.from_file(opin_filepath,
                                           io_utils.get_synonyms_filepath())

    neutral_opins = make_neutrals(news, synonyms, opinions)
    neutral_opins.save(neutral_filepath)

#
# Test
#
root = io_utils.test_root()
for n in io_utils.test_indices():
    entity_filepath = path.join(root, "art{}.ann".format(n))
    news_filepath = path.join(root, "art{}.txt".format(n))
    neutral_filepath = path.join(root, "art{}.neut.txt".format(n))
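The snippet stops after composing the test filepaths, but the train part
shows the intent: write *.neut.txt files with automatically generated
neutral opinions. A rough sketch of what a make_neutrals-style pass could
do, under the assumption (not confirmed by the snippet) that it pairs
co-occurring entities which are absent from the labeled opinions:

# Assumption-labeled sketch; the real make_neutrals operates on News,
# SynonymsCollection and OpinionCollection objects, not plain tuples.
def make_neutral_pairs(entity_values, labeled_pairs):
    neutral = []
    for source in entity_values:
        for target in entity_values:
            if source == target:
                continue
            if (source, target) in labeled_pairs:
                continue                        # already labeled, skip it
            neutral.append((source, target))    # candidate neutral opinion
    return neutral

print(make_neutral_pairs(["usa", "russia", "eu"], {("usa", "russia")}))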
Example #5
    def __init__(self, io, word_embedding, train_indices, test_indices,
                 synonyms_filepath, bag_size, words_per_news,
                 bags_per_minibatch, callback):
        assert (isinstance(io, NetworkIO))
        assert (isinstance(word_embedding, Embedding))
        assert (isinstance(callback, Callback))

        self.io = io
        self.sess = None
        self.train_indices = train_indices
        self.test_indices = test_indices
        self.words_per_news = words_per_news
        self.synonyms_filepath = synonyms_filepath
        self.synonyms = SynonymsCollection.from_file(self.synonyms_filepath)

        # Entity embeddings are initialized below for both the train and
        # the test collections (via EntityIndices over all documents).

        # words_per_news is the size of the window that must cover both
        # entities of a relation; relations that do not fit are filtered out:
        #     len([ ... entity_1 ... entity_2 ... ]) = window_size_in_words
        # TODO: the window size should remain unchanged.

        all_indices = train_indices + test_indices
        entities_collections = [
            EntityCollection.from_file(self.io.get_entity_filepath(n))
            for n in all_indices
        ]
        entity_indices = EntityIndices(entities_collections)

        # Train collection
        train_news_words_collection, train_relations_collection = self._process_into_collections(
            train_indices, entity_indices, word_embedding, words_per_news,
            True)

        # Test collection
        test_news_words_collection, test_relations_collection = self._process_into_collections(
            test_indices, entity_indices, word_embedding, words_per_news,
            False)

        # Clamp the window so it never exceeds the shortest news document
        # in either collection.
        words_per_news = min(
            train_news_words_collection.get_min_words_per_news_count(),
            test_news_words_collection.get_min_words_per_news_count(),
            words_per_news)

        self.train_relations_collection = train_relations_collection
        self.test_relations_collection = test_relations_collection
        self.test_news_words_collection = test_news_words_collection

        train_bags_collection = BagsCollection(
            train_relations_collection.relations, bag_size)
        test_bags_collection = BagsCollection(
            test_relations_collection.relations, bag_size)

        train_bags_collection.shuffle()
        test_bags_collection.shuffle()

        self.test_minibatches = test_bags_collection.to_minibatches(
            bags_per_minibatch)
        self.train_minibatches = train_bags_collection.to_minibatches(
            bags_per_minibatch)
        self.train_news_words_collection = train_news_words_collection

        # The test collection shares the same embedding matrix.
        self.E = train_news_words_collection.get_embedding_matrix()

        self.network = None
        self.callback = callback
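Relations are first packed into bags of bag_size, and the bags are then
grouped into minibatches of bags_per_minibatch. A minimal sketch of this
two-level chunking, with plain lists standing in for BagsCollection:

# Plain-list illustration; BagsCollection additionally shuffles the bags
# before to_minibatches() is called.
def chunk(items, size):
    return [items[i:i + size] for i in range(0, len(items), size)]

relations = list(range(10))
bags = chunk(relations, 2)       # bag_size = 2
minibatches = chunk(bags, 3)     # bags_per_minibatch = 3
print(minibatches)
# [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9]]]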
Example #6
    def iter_all_entity_collections():
        # Nested helper: `self` is taken from the enclosing method's scope.
        all_indices = self.io.get_data_indices(DataType.Train) + \
                      self.io.get_data_indices(DataType.Test)
        for news_index in all_indices:
            yield EntityCollection.from_file(self.io.get_entity_filepath(news_index),
                                             self.Settings.Stemmer)
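The helper lazily yields one EntityCollection per document over the
concatenated train and test indices. The same generator pattern in generic
form, with a hypothetical load callable in place of
EntityCollection.from_file:

# Generic form of the lazy iteration above; `load` is hypothetical.
def iter_loaded(train_indices, test_indices, load):
    for index in train_indices + test_indices:
        yield load(index)

for collection in iter_loaded([0, 1], [2], lambda i: "entities-%d" % i):
    print(collection)            # entities-0, entities-1, entities-2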