def PreProcessedQueries(wikipedia_dump_fname, vectors=wordvectors, queries=queries, redirects=page_redirects, surface=surface_counts):

    get_words = re.compile('[^a-zA-Z0-9 ]')
    get_link = re.compile(r'.*?\[(.*?)\].*?')

    wordvec = WordTokenizer(vectors, sentence_length=200)
    documentvec = WordTokenizer(vectors, sentence_length=1)

    queried_pages = set()
    for docs, q in queries.iteritems():
        wordvec.tokenize(docs)
        for sur, v in q.iteritems():
            wrds_sur = get_words.sub(' ', sur)
            wordvec.tokenize(wrds_sur)
            link_sur = get_link.match(sur).group(1)
            wordvec.tokenize(link_sur)
            for link in v['vals'].keys():
                wrds = get_words.sub(' ', link)
                wordvec.tokenize(wrds)
                tt = WikiRegexes.convertToTitle(link)
                documentvec.get_location(tt)
                queried_pages.add(tt)

    added_pages = set()
    for title in queried_pages:
        if title in redirects:
            #wordvec.tokenize(self.redirects[title])
            documentvec.get_location(redirects[title])
            added_pages.add(redirects[title])
    queried_pages |= added_pages

    page_content = {}

#     class GetWikipediaWords(WikipediaReader, WikiRegexes):

#         def readPage(ss, title, content):
#             tt = ss.convertToTitle(title)
#             if tt in queried_pages:
#                 cnt = ss._wikiToText(content)
#                 page_content[tt] = wordvec.tokenize(cnt)

#     GetWikipediaWords(wikipedia_dump_fname).read()

    rr = redirects
    rq = queried_pages
    rc = page_content
    rs = surface

    class PreProcessedQueriesCls(object):

        wordvecs = wordvec
        documentvecs = documentvec
        queries = queries
        redirects = rr
        queried_pages = rq
        page_content = rc
        surface_counts = rs


    return PreProcessedQueriesCls
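
The same two regexes recur in every snippet below: get_words blanks out everything except letters, digits and spaces, and get_link pulls the link target out of a surface string. A minimal, self-contained sketch of the "context [link] context" format these regexes imply (the example string is made up):

import re

get_words = re.compile(r'[^a-zA-Z0-9 ]')
get_link = re.compile(r'.*?\[(.*?)\].*?')

# hypothetical surface string in the "context [link] context" shape the code expects
sur = "the city of [Springfield] in Illinois"
print(get_words.sub(' ', sur))       # punctuation and brackets replaced by spaces
print(get_link.match(sur).group(1))  # -> Springfield
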
Example #2
    def create_entity_description_dict(self):
        if os.path.exists(self.dict_path + "entity_description_dict.pkl"):
            print("Entity description dictionary has existed!")
            return
        else:
            self.get_all_entities()
            re_pattern = re.compile('[^a-zA-Z0-9_ ]')
            print("Start Saving!")
            page_content = {}
            # use a set so membership tests during the dump scan are O(1)
            all_entity_title = {
                WikiRegexes.convertToTitle(a) for a in self.all_entity
            }

            class GetWikipediaWords(WikipediaReader, WikiRegexes):
                def readPage(ss, title, content, namespace):
                    if namespace != 0:
                        return
                    tt = ss.convertToTitle(title)
                    if tt in all_entity_title:
                        ctx = ss._wikiToText(content)
                        only_doc = re_pattern.sub('', ctx)
                        page_content[tt] = only_doc
                    # not recommended, only for convenience
                    if len(page_content) == len(self.all_entity):
                        print("Find all!")
                        DictOperator().save_dict(self.dict_path, "test.pkl",
                                                 page_content)

            GetWikipediaWords(self.wiki_dump_file).read()

            no_description_entity_num = 0
            for entity in self.all_entity:
                try:
                    self.entity_description_dict[entity] = page_content[
                        WikiRegexes.convertToTitle(entity)]
                except KeyError:
                    self.entity_description_dict[entity] = '<UNK>'
                    no_description_entity_num += 1
            print(
                "{} entities have no description in the wiki_dump file\n"
                .format(no_description_entity_num))
            print("Successfully created the entity description dictionary!\n")
            DictOperator().save_dict(self.dict_path,
                                     "entity_description_dict.pkl",
                                     self.entity_description_dict)
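
For reference, the saved file is keyed by the original entity names, with '<UNK>' for entities whose page never appeared in the dump. A minimal sketch of reading it back, assuming DictOperator().save_dict writes an ordinary pickle (which the .pkl name suggests) and that dict_path ends with a path separator, as the concatenation above implies:

import pickle

def load_entity_descriptions(dict_path, fname="entity_description_dict.pkl"):
    # plain pickle load of the dictionary written above
    with open(dict_path + fname, "rb") as f:
        return pickle.load(f)

# descriptions = load_entity_descriptions("./dicts/")
# descriptions.get("Some Entity", "<UNK>")
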
def PreProcessedQueries(wikipedia_dump_fname,
                        wordvec=wordvectors,
                        queries=queries,
                        redirects=page_redirects):

    get_words = re.compile('[^a-zA-Z0-9 ]')
    get_link = re.compile(r'.*?\[(.*?)\].*?')

    queried_pages = set()
    for docs, q in queries.iteritems():
        wordvec.tokenize(docs, length=200)
        for sur, v in q.iteritems():
            wrds_sur = get_words.sub(' ', sur)
            wordvec.tokenize(wrds_sur)
            link_sur = get_link.match(sur).group(1)
            wordvec.tokenize(link_sur)
            for link in v['vals'].keys():
                wrds = get_words.sub(' ', link)
                wordvec.tokenize(wrds)
                tt = WikiRegexes.convertToTitle(link)
                wordvec.get_location(tt)
                queried_pages.add(tt)

    added_pages = set()
    for title in queried_pages:
        if title in redirects:
            #wordvec.tokenize(self.redirects[title])
            added_pages.add(redirects[title])
    queried_pages |= added_pages

    page_content = {}

    #     class GetWikipediaWords(WikipediaReader, WikiRegexes):

    #         def readPage(ss, title, content):
    #             tt = ss.convertToTitle(title)
    #             if tt in queried_pages:
    #                 cnt = ss._wikiToText(content)
    #                 page_content[tt] = wordvec.tokenize(cnt)

    #     GetWikipediaWords(wikipedia_dump_fname).read()

    rr = redirects
    rq = queried_pages
    rc = page_content

    class PreProcessedQueriesCls(object):

        wordvecs = wordvec
        queries = queries
        redirects = rr
        queried_pages = rq
        page_content = rc

    return PreProcessedQueriesCls
    def compute_batch(self,
                      isTraining=True,
                      useTrainingFunc=True,
                      batch_run_func=None):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        if batch_run_func is None:
            batch_run_func = self.run_batch
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        get_words = re.compile('[^a-zA-Z0-9 ]')
        get_link = re.compile(r'.*?\[(.*?)\].*?')

        empty_sentence = np.zeros(self.sentence_length, dtype='int32')

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]['training'] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(
                self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(
                    self.wordvecs.tokenize(get_words.sub(' ', surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(
                    self.wordvecs.tokenize(surlink,
                                           length=self.sentence_length_short))
                surmatch = surlink.lower()
                surcounts = self.surface_counts.get(surmatch)
                if not surcounts:
                    self.failed_match.append(surmatch)
                    surcounts = {}
                target_body_words_input = []  # words from the target document
                target_words_input = []  # the words from the target title
                target_matches_surface = []
                target_inputs = []  # the target vector
                target_learings = []
                target_match_counts = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                #                 target_feat_indicators = []

                denotations_joint_indicators = []
                denotations_linked_query = []
                denotations_range = []

                denotation_target_linked = []

                target_isgold = []

                queries_feats_indicators = []
                for ind in targets['query_vals']:
                    query_feats = np.zeros((self.num_indicator_features, ),
                                           dtype='int8')
                    #                     query_feats[ind] = 1
                    queries_feats_indicators.append(query_feats)
                queries_len = len(targets['query_vals'])

                for target in set(targets['vals'].keys() + random.sample(
                        self.documentvecs.reverse_word_location,
                        self.num_negative_target_samples)) - {
                            None,
                        }:
                    isGold = target in targets['gold']
                    wiki_title = WikiRegexes.convertToTitle(target)
                    cnt_wrds = self.page_content.get(
                        wiki_title)  #WikiRegexes.convertToTitle(target))
                    cnt = self.documentvecs.get_location(wiki_title)
                    if wiki_title == 'nil':
                        cnt = 0  # this is the stop symbol location
                    if cnt is None:
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        if cnt_wrds is None:
                            # really know nothing
                            continue
                        else:
                            # we must not have had enough links to this document
                            # but still have the target text
                            cnt = 0
                    if isGold:
                        target_gold_loc = len(target_inputs)
                        target_isgold.append(1)
                    else:
                        target_isgold.append(0)
                    target_body_words_input.append(
                        cnt_wrds if cnt_wrds is not None else empty_sentence)
                    target_words_input.append(
                        self.wordvecs.tokenize(
                            get_words.sub(' ', target),
                            length=self.sentence_length_short))
                    target_inputs.append(cnt)
                    # page_content already tokenized
                    target_matches_surface.append(
                        int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                    target_match_counts.append(surcounts.get(wiki_title, 0))

                    joint_indicators = []
                    query_idx = []
                    indicators_place = targets['vals'].get(target)
                    if indicators_place:
                        # [queries][indicator id]
                        for indx in xrange(len(indicators_place[1])):
                            local_feats = np.zeros(
                                (self.num_indicator_features, ), dtype='int8')
                            local_feats[indicators_place[1][indx]] = 1
                            local_feats[targets['query_vals']
                                        [indx]] = 1  # features from the joint
                            #                             if isGold:  #################################### hack
                            #                                 local_feats[-1] = 1
                            joint_indicators.append(local_feats)
                            query_idx.append(len(self.current_queries) + indx)
                    else:
                        raise NotImplementedError()  # the fallback below is currently unreachable
                        for indx in xrange(queries_len):
                            local_feats = np.zeros(
                                (self.num_indicator_features, ), dtype='int8')
                            local_feats[self.impossible_query] = 1
                            joint_indicators.append(local_feats)
                            query_idx.append(len(self.current_queries) + indx)

                    start_range = len(denotations_joint_indicators) + len(
                        self.current_denotations_feats_indicators)
                    denotations_joint_indicators += joint_indicators
                    denotations_linked_query += query_idx
                    denotations_range.append(
                        [start_range, start_range + len(joint_indicators)])
                    denotation_target_linked += [
                        len(self.current_target_words) +
                        len(target_words_input) - 1
                    ] * len(query_idx)


#                     indicators = np.zeros((self.num_indicator_features,), dtype='int8')
#                     if indicators_place:

#                         indicators[indicators_place[1]] = 1
#                     target_feat_indicators.append(indicators)

#if wiki_title not in surcounts:
#    print surcounts, wiki_title
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    #                     gold_loc = (len(self.current_target_goal) + target_gold_loc)
                    sorted_match_counts = [-4, -3, -2, -1] + sorted(
                        set(target_match_counts))
                    #print sorted_match_counts
                    target_match_counts_indicators = [[
                        int(s == sorted_match_counts[-1]),
                        int(s == sorted_match_counts[-2]),
                        int(s == sorted_match_counts[-3]),
                        int(0 < s <= sorted_match_counts[-4]),
                        int(s == 0),
                    ] for s in target_match_counts]
                    #                     self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input
                    self.current_target_matches_surface += target_matches_surface
                    self.current_surface_target_counts += target_match_counts_indicators
                    self.current_target_body_words += target_body_words_input
                    #                     self.current_feat_indicators += target_feat_indicators
                    self.current_target_is_gold += target_isgold

                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append([
                        target_group_start,
                        target_group_end,
                        -1  # gold_loc
                    ])
                    #self.current_boosted_groups.append(targets['boosted'])

                    self.current_queries += queries_feats_indicators

                    self.current_denotations_feats_indicators += denotations_joint_indicators
                    self.current_denotations_related_query += denotations_linked_query
                    self.current_denotations_range += denotations_range

                    self.current_denotation_targets_linked += denotation_target_linked

                #self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                #                 return
                batch_run_func(func)
                sys.stderr.write('%i\r' % self.total_links)
                if self.total_links > self.num_training_items:
                    return self.total_loss / self.total_links,  #self.total_boosted_loss / self.total_links

        if len(self.current_target_id) > 0:
            batch_run_func(func)
            #self.run_batch(func)

        return self.total_loss / self.total_links,  #self.total_boosted_loss / self.total_links
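
The sorted_match_counts block above turns raw link counts into five indicator features per candidate: most linked sense for this surface form, second, third, some rarer non-zero count, or never linked. A standalone illustration of that ranking with toy counts:

# toy link counts for the candidates of one surface form
target_match_counts = [120, 7, 120, 0, 3]

# sentinels keep the [-2]/[-3]/[-4] indices valid even with few distinct counts
sorted_match_counts = [-4, -3, -2, -1] + sorted(set(target_match_counts))

target_match_counts_indicators = [[
    int(s == sorted_match_counts[-1]),      # most linked candidate
    int(s == sorted_match_counts[-2]),      # second most linked
    int(s == sorted_match_counts[-3]),      # third most linked
    int(0 < s <= sorted_match_counts[-4]),  # linked, but rarer than the top three
    int(s == 0),                            # never linked from this surface form
] for s in target_match_counts]

print(target_match_counts_indicators)
# [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1], [0, 0, 1, 0, 0]]
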
    def compute_batch(self, isTraining=True, useTrainingFunc=True, batch_run_func=None):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        if batch_run_func is None:
            batch_run_func = self.run_batch
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        get_words = re.compile("[^a-zA-Z0-9 ]")
        get_link = re.compile(r".*?\[(.*?)\].*?")

        empty_sentence = np.zeros(self.sentence_length, dtype="int32")

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]["training"] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(self.wordvecs.tokenize(get_words.sub(" ", surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(self.wordvecs.tokenize(surlink, length=self.sentence_length_short))
                surmatch = surlink.lower()
                surcounts = self.surface_counts.get(surmatch)
                if not surcounts:
                    self.failed_match.append(surmatch)
                    surcounts = {}
                target_body_words_input = []  # words from the target document
                target_words_input = []  # the words from the target title
                target_matches_surface = []
                target_inputs = []  # the target vector
                target_learings = []
                target_match_counts = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                #                 target_feat_indicators = []

                denotations_joint_indicators = []
                denotations_linked_query = []
                denotations_range = []

                denotation_target_linked = []

                target_isgold = []

                queries_feats_indicators = []
                for ind in targets["query_vals"]:
                    query_feats = np.zeros((self.num_indicator_features,), dtype="int8")
                    #                     query_feats[ind] = 1
                    queries_feats_indicators.append(query_feats)
                queries_len = len(targets["query_vals"])

                for target in set(
                    targets["vals"].keys()
                    + random.sample(self.documentvecs.reverse_word_location, self.num_negative_target_samples)
                ) - {None}:
                    isGold = target in targets["gold"]
                    wiki_title = WikiRegexes.convertToTitle(target)
                    cnt_wrds = self.page_content.get(wiki_title)  # WikiRegexes.convertToTitle(target))
                    cnt = self.documentvecs.get_location(wiki_title)
                    if wiki_title == "nil":
                        cnt = 0  # this is the stop symbol location
                    if cnt is None:
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        if cnt_wrds is None:
                            # really know nothing
                            continue
                        else:
                            # we must not have had enough links to this document
                            # but still have the target text
                            cnt = 0
                    if isGold:
                        target_gold_loc = len(target_inputs)
                        target_isgold.append(1)
                    else:
                        target_isgold.append(0)
                    target_body_words_input.append(cnt_wrds if cnt_wrds is not None else empty_sentence)
                    target_words_input.append(
                        self.wordvecs.tokenize(get_words.sub(" ", target), length=self.sentence_length_short)
                    )
                    target_inputs.append(cnt)
                    # page_content already tokenized
                    target_matches_surface.append(int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                    target_match_counts.append(surcounts.get(wiki_title, 0))

                    joint_indicators = []
                    query_idx = []
                    indicators_place = targets["vals"].get(target)
                    if indicators_place:
                        # [queries][indicator id]
                        for indx in xrange(len(indicators_place[1])):
                            local_feats = np.zeros((self.num_indicator_features,), dtype="int8")
                            local_feats[indicators_place[1][indx]] = 1
                            local_feats[targets["query_vals"][indx]] = 1  # features from the joint
                            #                             if isGold:  #################################### hack
                            #                                 local_feats[-1] = 1
                            joint_indicators.append(local_feats)
                            query_idx.append(len(self.current_queries) + indx)
                    else:
                        raise NotImplementedError()  # the fallback below is currently unreachable
                        for indx in xrange(queries_len):
                            local_feats = np.zeros((self.num_indicator_features,), dtype="int8")
                            local_feats[self.impossible_query] = 1
                            joint_indicators.append(local_feats)
                            query_idx.append(len(self.current_queries) + indx)

                    start_range = len(denotations_joint_indicators) + len(self.current_denotations_feats_indicators)
                    denotations_joint_indicators += joint_indicators
                    denotations_linked_query += query_idx
                    denotations_range.append([start_range, start_range + len(joint_indicators)])
                    denotation_target_linked += [len(self.current_target_words) + len(target_words_input) - 1] * len(
                        query_idx
                    )

                #                     indicators = np.zeros((self.num_indicator_features,), dtype='int8')
                #                     if indicators_place:

                #                         indicators[indicators_place[1]] = 1
                #                     target_feat_indicators.append(indicators)

                # if wiki_title not in surcounts:
                #    print surcounts, wiki_title
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    #                     gold_loc = (len(self.current_target_goal) + target_gold_loc)
                    sorted_match_counts = [-4, -3, -2, -1] + sorted(set(target_match_counts))
                    # print sorted_match_counts
                    target_match_counts_indicators = [
                        [
                            int(s == sorted_match_counts[-1]),
                            int(s == sorted_match_counts[-2]),
                            int(s == sorted_match_counts[-3]),
                            int(0 < s <= sorted_match_counts[-4]),
                            int(s == 0),
                        ]
                        for s in target_match_counts
                    ]
                    #                     self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input
                    self.current_target_matches_surface += target_matches_surface
                    self.current_surface_target_counts += target_match_counts_indicators
                    self.current_target_body_words += target_body_words_input
                    #                     self.current_feat_indicators += target_feat_indicators
                    self.current_target_is_gold += target_isgold

                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append([target_group_start, target_group_end, -1])  # gold_loc
                    # self.current_boosted_groups.append(targets['boosted'])

                    self.current_queries += queries_feats_indicators

                    self.current_denotations_feats_indicators += denotations_joint_indicators
                    self.current_denotations_related_query += denotations_linked_query
                    self.current_denotations_range += denotations_range

                    self.current_denotation_targets_linked += denotation_target_linked

                # self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                #                 return
                batch_run_func(func)
                sys.stderr.write("%i\r" % self.total_links)
                if self.total_links > self.num_training_items:
                    return (self.total_loss / self.total_links,)  # self.total_boosted_loss / self.total_links

        if len(self.current_target_id) > 0:
            batch_run_func(func)
            # self.run_batch(func)

        return (self.total_loss / self.total_links,)  # self.total_boosted_loss / self.total_links
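
The candidate set in this variant mixes the known candidates for a surface form with a few randomly sampled page titles as negatives, then drops None. A small, self-contained sketch of that construction (the title list is made up; in the real code it is self.documentvecs.reverse_word_location):

import random

known_candidates = {"Springfield_(Illinois)": 1, "Springfield_(Massachusetts)": 2}
all_page_titles = ["Paris", "Chicago", None, "Springfield_(Illinois)", "Berlin", "Tokyo"]
num_negative_target_samples = 2

candidates = set(
    list(known_candidates.keys())
    + random.sample(all_page_titles, num_negative_target_samples)
) - {None}

print(candidates)  # the known candidates plus up to 2 random negatives, never None
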
Example #6
def PreProcessedQueries(
        wikipedia_dump_fname,
        vectors,    # =wordvectors
        queries,    # =queries
        redirects,  # =page_redirects
        surface,    # =surface_counts
):

    get_words = re.compile('[^a-zA-Z0-9 ]')
    get_link = re.compile(r'.*?\[(.*?)\].*?')

    wordvec = WordTokenizer(vectors, sentence_length=200)
    documentvec = WordTokenizer(vectors, sentence_length=1)

    queried_pages = set()
    for docs, q in queries.iteritems():
        wordvec.tokenize(docs)
        for sur, v in q.iteritems():
            wrds_sur = get_words.sub(' ', sur)
            wordvec.tokenize(wrds_sur)
            link_sur = get_link.match(sur).group(1)
            wordvec.tokenize(link_sur)
            for link in v['vals'].keys():
                wrds = get_words.sub(' ', link)
                wordvec.tokenize(wrds)
                tt = WikiRegexes.convertToTitle(link)
                documentvec.get_location(tt)
                queried_pages.add(tt)


    added_pages = set()
    for title in queried_pages:
        if title in redirects:
            #wordvec.tokenize(self.redirects[title])
            documentvec.get_location(redirects[title])
            added_pages.add(redirects[title])
    queried_pages |= added_pages

    for w in queried_pages:
        wordvec.tokenize(get_words.sub(' ', w))

    page_content = {}

    class GetWikipediaWords(WikipediaReader, WikiRegexes):

        def readPage(ss, title, content, namespace):
            if namespace != 0:
                return
            tt = ss.convertToTitle(title)
            if tt in queried_pages:
                cnt = ss._wikiToText(content)
                page_content[tt] = wordvec.tokenize(cnt)

    GetWikipediaWords(wikipedia_dump_fname).read()

    rr = redirects
    rq = queried_pages
    rc = page_content
    rs = surface

    qq = queries

    class PreProcessedQueriesCls(object):

        wordvecs = wordvec
        documentvecs = documentvec
        queries = qq
        redirects = rr
        queried_pages = rq
        page_content = rc
        surface_counts = rs


    return PreProcessedQueriesCls
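
The redirect step above folds the redirect targets of every queried page back into the set, so the dump pass also captures the content those redirects point at. A compact sketch of that resolution with toy data:

# toy redirect table and query set; the real code builds these from the dump
redirects = {"NYC": "New_york_city", "Big_Apple": "New_york_city"}
queried_pages = {"NYC", "Chicago"}

added_pages = {redirects[title] for title in queried_pages if title in redirects}
queried_pages |= added_pages

print(queried_pages)  # {'NYC', 'Chicago', 'New_york_city'} (set order may vary)
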
    def compute_batch(self, isTraining=True, useTrainingFunc=True):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        get_words = re.compile('[^a-zA-Z0-9 ]')
        get_link = re.compile(r'.*?\[(.*?)\].*?')

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]['training'] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(
                self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(
                    self.wordvecs.tokenize(get_words.sub(' ', surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(
                    self.wordvecs.tokenize(surlink))
                surmatch = surlink.lower()
                #target_page_input = []
                target_words_input = []
                target_matches_surface = []
                target_inputs = []
                target_learings = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                for target in targets['vals'].keys():
                    # skip the items that we don't know the gold for
                    if not targets['gold'] and isTraining:
                        continue
                    isGold = target == targets['gold']
                    #cnt = self.page_content.get(WikiRegexes.convertToTitle(target))
                    cnt = self.wordvecs.get_location(
                        WikiRegexes.convertToTitle(target))
                    if cnt is None:
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        continue
                    if isGold:
                        target_gold_loc = len(target_inputs)
                    #target_page_input.append(cnt)
                    target_words_input.append(
                        self.wordvecs.tokenize(get_words.sub(' ', target)))
                    target_inputs.append(cnt)  # page_content already tokenized
                    target_matches_surface.append(
                        int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    gold_loc = (len(self.current_target_goal) +
                                target_gold_loc)
                    self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input  # TODO: add
                    self.current_target_matches_surface += target_matches_surface
                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append(
                        [target_group_start, target_group_end, gold_loc])

                #self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                #return
                self.run_batch(func)
                if self.total_links > self.num_training_items:
                    return self.total_loss / self.total_links

        if len(self.current_target_id) > 0:
            self.run_batch(func)

        return self.total_loss / self.total_links
Example #8
    def compute_batch(self, isTraining=True, useTrainingFunc=True):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        self.failed_match = []
        self.failed_page_match = []

        get_words = re.compile('[^a-zA-Z0-9 ]')
        get_link = re.compile(r'.*?\[(.*?)\].*?')

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]['training'] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(
                self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(
                    self.wordvecs.tokenize(get_words.sub(' ', surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(
                    self.wordvecs.tokenize(surlink))
                surmatch = surlink.lower()
                surcounts = self.surface_counts.get(surmatch)
                if not surcounts:
                    self.failed_match.append(surmatch)
                    surcounts = {}
                #target_page_input = []
                target_words_input = []
                target_matches_surface = []
                target_inputs = []
                target_learings = []
                target_match_counts = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                for target in set(targets['vals'].keys() + random.sample(
                        self.documentvecs.reverse_word_location, 3)) - {
                            None,
                        }:
                    # skip the items that we don't know the gold for
                    if not targets['gold'] and isTraining:
                        continue
                    isGold = target == targets['gold']
                    #cnt = self.page_content.get(WikiRegexes.convertToTitle(target))
                    wiki_title = WikiRegexes.convertToTitle(target)
                    cnt = self.documentvecs.get_location(wiki_title)
                    if cnt is None:  #or wiki_title == 'nil':
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        #
                        # there are also nil queries that are generated for every document
                        # but we actually have a nil page that is getting referenced
                        # so just filter it out for now
                        continue
                    if isGold:
                        target_gold_loc = len(target_inputs)
                    #target_page_input.append(cnt)
                    target_words_input.append(
                        self.wordvecs.tokenize(get_words.sub(' ', target)))
                    target_inputs.append(cnt)  # page_content already tokenized
                    target_matches_surface.append(
                        int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                    tmc = surcounts.get(wiki_title, 0)
                    if tmc == 0:
                        self.failed_page_match.append((surcounts, wiki_title))
                    target_match_counts.append(tmc)
                    #if wiki_title not in surcounts:
                    #    print surcounts, wiki_title
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    gold_loc = (len(self.current_target_goal) +
                                target_gold_loc)
                    sorted_match_counts = [-4, -3, -2, -1] + sorted(
                        set(target_match_counts))
                    #print sorted_match_counts
                    target_match_counts_indicators = [[
                        int(s == sorted_match_counts[-1]),
                        int(s == sorted_match_counts[-2]),
                        int(s == sorted_match_counts[-3]),
                        int(s <= sorted_match_counts[-4]),
                    ] for s in target_match_counts]
                    self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input  # TODO: add
                    self.current_target_matches_surface += target_matches_surface
                    self.current_surface_target_counts += target_match_counts_indicators
                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append(
                        [target_group_start, target_group_end, gold_loc])

                #self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                self.run_batch(func)
                if self.total_links > self.num_training_items:
                    return self.total_loss / self.total_links

        if len(self.current_target_id) > 0:
            self.run_batch(func)

        return self.total_loss / self.total_links
    def compute_batch(self, isTraining=True, useTrainingFunc=True):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        get_words = re.compile('[^a-zA-Z0-9 ]')
        get_link = re.compile(r'.*?\[(.*?)\].*?')

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]['training'] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(self.wordvecs.tokenize(get_words.sub(' ', surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(self.wordvecs.tokenize(surlink))
                surmatch = surlink.lower()
                #target_page_input = []
                target_words_input = []
                target_matches_surface = []
                target_inputs = []
                target_learings = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                for target in targets['vals'].keys():
                    # skip the items that we don't know the gold for
                    if not targets['gold'] and isTraining:
                        continue
                    isGold = target == targets['gold']
                    #cnt = self.page_content.get(WikiRegexes.convertToTitle(target))
                    cnt = self.wordvecs.get_location(WikiRegexes.convertToTitle(target))
                    if cnt is None:
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        continue
                    if isGold:
                        target_gold_loc = len(target_inputs)
                    #target_page_input.append(cnt)
                    target_words_input.append(self.wordvecs.tokenize(get_words.sub(' ', target)))
                    target_inputs.append(cnt)  # page_content already tokenized
                    target_matches_surface.append(int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    gold_loc = (len(self.current_target_goal) + target_gold_loc)
                    self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input   # TODO: add
                    self.current_target_matches_surface += target_matches_surface
                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append(
                        [target_group_start, target_group_end,
                         gold_loc])

                #self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                #return
                self.run_batch(func)
                if self.total_links > self.num_training_items:
                    return self.total_loss / self.total_links

        if len(self.current_target_id) > 0:
            self.run_batch(func)

        return self.total_loss / self.total_links
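
Each current_learning_groups entry above is a [start, end, gold_loc] triple indexing into the flat per-target lists, so a later pass can recover which contiguous slice of targets belongs to one mention and where its gold target sits. A toy illustration of that bookkeeping:

current_target_input = []
current_learning_groups = []

def add_group(target_ids, gold_index_in_group):
    # record the slice of the flat target list occupied by this mention's candidates
    start = len(current_target_input)
    current_target_input.extend(target_ids)
    end = len(current_target_input)
    current_learning_groups.append([start, end, start + gold_index_in_group])

add_group([101, 102, 103], gold_index_in_group=1)  # gold is the second candidate
add_group([201, 202], gold_index_in_group=0)

print(current_learning_groups)          # [[0, 3, 1], [3, 5, 3]]
start, end, gold_loc = current_learning_groups[1]
print(current_target_input[start:end])  # [201, 202]
print(current_target_input[gold_loc])   # 201
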
    def compute_batch(self, isTraining=True, useTrainingFunc=True):
        if isTraining and useTrainingFunc:
            func = self.train_func
        else:
            func = self.test_func
        self.reset_accums()
        self.total_links = 0
        self.total_loss = 0.0

        self.failed_match = []
        self.failed_page_match = []

        get_words = re.compile('[^a-zA-Z0-9 ]')
        get_link = re.compile(r'.*?\[(.*?)\].*?')

        for doc, queries in self.queries.iteritems():
            # skip the testing documents while training and vice versa
            if queries.values()[0]['training'] != isTraining:
                continue
            docid = len(self.current_documents)
            self.current_documents.append(self.wordvecs.tokenize(doc, length=self.document_length))
            for surtxt, targets in queries.iteritems():
                self.current_link_id.append(docid)
                surid = len(self.current_surface_link)
                self.current_surface_context.append(self.wordvecs.tokenize(get_words.sub(' ', surtxt)))
                surlink = get_link.match(surtxt).group(1)
                self.current_surface_link.append(self.wordvecs.tokenize(surlink))
                surmatch = surlink.lower()
                surcounts = self.surface_counts.get(surmatch)
                if not surcounts:
                    self.failed_match.append(surmatch)
                    surcounts = {}
                #target_page_input = []
                target_words_input = []
                target_matches_surface = []
                target_inputs = []
                target_learings = []
                target_match_counts = []
                target_gold_loc = -1
                target_group_start = len(self.current_target_input)
                for target in set(targets['vals'].keys() + random.sample(self.documentvecs.reverse_word_location, 3)) - {None,}:
                    # skip the items that we don't know the gold for
                    if not targets['gold'] and isTraining:
                        continue
                    isGold = target == targets['gold']
                    #cnt = self.page_content.get(WikiRegexes.convertToTitle(target))
                    wiki_title = WikiRegexes.convertToTitle(target)
                    cnt = self.documentvecs.get_location(wiki_title)
                    if cnt is None:  # or wiki_title == 'nil':
                        # we were not able to find this wikipedia document,
                        # so just ignore this result since trying to train on
                        # it would cause issues
                        #
                        # there are also nil queries that are generated for every document
                        # but we actually have a nil page that is getting referenced
                        # so just filter it out for now
                        continue
                    if isGold:
                        target_gold_loc = len(target_inputs)
                    #target_page_input.append(cnt)
                    target_words_input.append(self.wordvecs.tokenize(get_words.sub(' ', target)))
                    target_inputs.append(cnt)  # page_content already tokenized
                    target_matches_surface.append(int(surmatch == target.lower()))
                    target_learings.append((targets, target))
                    tmc = surcounts.get(wiki_title, 0)
                    if tmc == 0:
                        self.failed_page_match.append((surcounts, wiki_title))
                    target_match_counts.append(tmc)
                    #if wiki_title not in surcounts:
                    #    print surcounts, wiki_title
                if target_gold_loc != -1 or not isTraining:  # skip groups where we never found the gold item while training
                    # contains the index of the gold item for these items, so it can be less than it
                    gold_loc = (len(self.current_target_goal) + target_gold_loc)
                    sorted_match_counts = [-4,-3,-2,-1] + sorted(set(target_match_counts))
                    #print sorted_match_counts
                    target_match_counts_indicators = [
                        [
                            int(s == sorted_match_counts[-1]),
                            int(s == sorted_match_counts[-2]),
                            int(s == sorted_match_counts[-3]),
                            int(s <= sorted_match_counts[-4]),
                        ]
                        for s in target_match_counts
                    ]
                    self.current_target_goal += [gold_loc] * len(target_inputs)
                    self.current_target_input += target_inputs
                    self.current_target_id += [surid] * len(target_inputs)
                    self.current_target_words += target_words_input   # TODO: add
                    self.current_target_matches_surface += target_matches_surface
                    self.current_surface_target_counts += target_match_counts_indicators
                    target_group_end = len(self.current_target_input)
                    self.current_learning_groups.append(
                        [target_group_start, target_group_end,
                         gold_loc])

                #self.current_target_goal.append(isGold)
                self.learning_targets += target_learings
            if len(self.current_target_id) > self.batch_size:
                self.run_batch(func)
                if self.total_links > self.num_training_items:
                    return self.total_loss / self.total_links

        if len(self.current_target_id) > 0:
            self.run_batch(func)

        return self.total_loss / self.total_links
Example #11
    def process_tsv_file(self):
        with open(self.tsv_file, "r", encoding="utf-8") as f:
            line = f.readline()
            query_num = 0
            no_dscpt = 0
            total_entity = 0
            while line:
                doc_id, mention, local_ctx, gold_entity, train_state = line.strip(
                ).split("\t")

                self.query_dict[query_num] = self.init_query_info_dict(0)
                self.original_query_dict[
                    query_num] = self.init_query_info_dict(1)

                self.query_dict[query_num]['train_state'] = int(train_state)
                self.original_query_dict[query_num]['train_state'] = int(
                    train_state)

                self.query_dict[query_num]['doc_id'] = str(doc_id)
                self.original_query_dict[query_num]['doc_id'] = str(doc_id)

                mention_wordid, mention_mask = self.get_word_id(
                    mention, self.config.mention_len, False)
                self.query_dict[query_num]['mention_wordid'].append(
                    mention_wordid)
                self.query_dict[query_num]['mention_mask'].append(mention_mask)
                self.original_query_dict[query_num]['mention'] = mention

                only_local_ctx = self.re_pattern.sub(' ', local_ctx)
                local_ctx_wordid, local_ctx_mask = self.get_word_id(
                    only_local_ctx, self.config.local_ctx_len, True)
                self.query_dict[query_num]['local_ctx_wordid'].append(
                    local_ctx_wordid)
                self.query_dict[query_num]['local_ctx_mask'].append(
                    local_ctx_mask)
                self.original_query_dict[query_num]['local_ctx'] = local_ctx

                self.query_dict[query_num]['gold_entity'].append(gold_entity)
                self.original_query_dict[query_num]['gold_entity'].append(
                    gold_entity)

                if ((self.process_state == 1 and self.raw_data_format == 'json')
                        or (self.process_state == 0 and self.raw_data_format == 'use_json')):
                    try:
                        candidate_entity_list = self.alia_entity_dict_list[int(
                            train_state)][mention]
                    except KeyError:
                        candidate_entity_list = []
                else:
                    try:
                        candidate_entity_list = self.alia_entity_dict[mention]
                    except KeyError:
                        candidate_entity_list = []
                addition_entity = ['NIL']
                # if str(train_state) == '0':
                # 	if len(candidate_entity_list) < 10 and len(candidate_entity_list) > 0:
                # 		addition_entity += random.sample(self.all_entity, 10 - len(candidate_entity_list))
                #for entity in candidate_entity_list + ['-NIL-']:
                for entity in candidate_entity_list + addition_entity:
                    self.query_dict[query_num]['candidate_entity'].append(
                        entity)
                    self.original_query_dict[query_num][
                        'candidate_entity'].append(entity)

                    wiki_entity_title = WikiRegexes.convertToTitle(entity)
                    #wiki_entity_title = entity
                    wiki_entity_title_wordid, wiki_entity_title_mask = self.get_word_id(
                        wiki_entity_title, self.config.wiki_title_len, False)
                    self.query_dict[query_num][
                        'candidate_entity_title_wordid'].append(
                            wiki_entity_title_wordid)
                    self.query_dict[query_num]['entity_title_mask'].append(
                        wiki_entity_title_mask)
                    self.original_query_dict[query_num][
                        'candidate_entity_title'].append(
                            [entity, wiki_entity_title])

                    total_entity += 1
                    if entity not in self.entity_description_dict:
                        no_dscpt += 1
                        description = "<UNK>"
                    else:
                        description = self.entity_description_dict[entity]
                    # shorten the description to save processing time
                    description = description[:min(
                        len(description), self.config.wiki_doc_len * 15)]
                    entity_description_wordid, description_mask = self.get_word_id(
                        description, self.config.wiki_doc_len, True)
                    self.query_dict[query_num][
                        'candidate_description_wordid'].append(
                            entity_description_wordid)
                    self.query_dict[query_num]['description_mask'].append(
                        description_mask)
                    self.original_query_dict[query_num][
                        'candidate_description'].append(description)
                    #self.query_dict[query_num]['description_mask'].append(self.get_conv_mask(description_mask))

                    self.query_dict[query_num]['candidate_score'].append(0)
                line = f.readline()
                query_num += 1

        DictOperator().save_dict(self.dict_path, self.final_query_file,
                                 self.query_dict)
        DictOperator().save_dict(
            self.dict_path,
            self.final_query_file.split(".")[0] + "_words.pkl",
            self.original_query_dict)
        print("{} of {} candidate entities had no description".format(no_dscpt, total_entity))
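
get_word_id is project code, but from its use above it appears to return a fixed-length list of word ids plus a 0/1 mask marking real tokens versus padding. A minimal sketch of that padding-and-mask pattern, assuming a plain whitespace tokenizer and id 0 for both unknown and padding tokens:

def pad_word_ids(text, length, vocab, unk_id=0, pad_id=0):
    # map tokens to ids, truncate or pad to a fixed length, and build the mask
    ids = [vocab.get(w, unk_id) for w in text.split()][:length]
    mask = [1] * len(ids) + [0] * (length - len(ids))
    ids = ids + [pad_id] * (length - len(ids))
    return ids, mask

vocab = {"barack": 3, "obama": 4, "president": 7}
ids, mask = pad_word_ids("barack obama was president", 6, vocab)
print(ids)   # [3, 4, 0, 7, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
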