Example #1
    def calc_v1_matrix(self, entity_id_list):
        # find links that are within the set of nodes we are passed,
        # plus all links inbound to them and outbound from them
        from_list = []
        to_list = []
        value_list = []
        max_id = 0
        wds = WikipediaDataset()
        for from_entity_id in entity_id_list:
            link_to = wds.get_links_to(from_entity_id)
            for v in link_to:
                from_list.append(from_entity_id)
                to_list.append(v)
                value_list.append(1)
                if v > max_id:
                    max_id = v

            link_from = set(wds.get_links_from(from_entity_id))
            for v in link_from:
                from_list.append(v)
                to_list.append(from_entity_id)
                value_list.append(1)
                if v > max_id:
                    max_id = v
        # TODO The following line threw a ValueError (row index exceeds matrix dimensions) on docid 579 and docid 105
        try:
            mtx = sparse.coo_matrix((value_list, (from_list, to_list)),
                                    shape=(max_id + 1, max_id + 1))
        except ValueError as e:
            self.logger.warning(
                'An error occurred; returning None rather than a V1 matrix. %s',
                e)
            return None
        return mtx
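
A minimal usage sketch for calc_v1_matrix (the entity ids below are arbitrary placeholders; GraphUtils is the class this method belongs to, shown in full in Example #18). The method may return None, so callers should check before converting the COO matrix:

graph_utils = GraphUtils()
mtx = graph_utils.calc_v1_matrix([1563047, 7412236])  # placeholder Wikipedia curids
if mtx is not None:
    csr = mtx.tocsr()          # row-oriented form, convenient for out-link lookups
    print(csr.shape, csr.nnz)  # shape is governed by the largest entity id seen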
Example #2
File: hack3.py Project: dwanev/SEL
    def calc_v1_matrix(self, v0):

        from_list = []
        to_list = []
        value_list = []
        max_id = 0

        wds = WikipediaDataset()

        for from_entity_id in v0:
            link_to = wds.get_links_to(from_entity_id)
            for v in link_to:
                from_list.append(from_entity_id)
                to_list.append(v)
                value_list.append(1)
                if v > max_id:
                    max_id = v

            link_from = set(wds.get_links_from(from_entity_id))
            for v in link_from:
                from_list.append(v)
                to_list.append(from_entity_id)
                value_list.append(1)
                if v > max_id:
                    max_id = v

        mtx = sparse.coo_matrix((value_list, (from_list, to_list)), shape=(max_id + 1, max_id + 1))

        full_set = set(to_list)
        full_set.update(from_list)

        return mtx, full_set
Example #3
    def check_for_wikititle_collisions(self, case_insensitive=True):
        input_file = gzip.open("E:\\tmp\\" + 'wikipedia-dump.json.gz', 'rt', encoding='utf-8')

        wd = WikipediaDataset()
        wikititle_mt = wd.get_wikititle_case_insensitive_marisa_trie()
        wikititle_id_by_id = {}
        fname_prefix = self.get_intermediate_path()+'wikititle_id_by_id.'
        if case_insensitive:
            fname_prefix = fname_prefix + 'case_insensitive.'

        count = 1
        collision_count = 1
        line = ''

        duplicate_ids_by_wikititle = {}

        while count < 25000000 and line is not None:  # TODO check termination and remove magic number
            log_progress = count < 50000 and count % 10000 == 0
            if log_progress:
                self.logger.info('starting gc ')
                gc.collect()  # have no real reason to think this is needed or will help the memory issue
                self.logger.info('%d lines processed', count)

            save_progress = count % 1000000 == 0 or count == 10
            if save_progress:
                self.logger.info('%d lines processed', count)
                wikititle_by_id_filename = fname_prefix + str(count) + '.pickle'
                self.logger.info('about to save to %s', wikititle_by_id_filename)
                with open(wikititle_by_id_filename, 'wb') as handle:
                    pickle.dump(wikititle_id_by_id, handle, protocol=pickle.HIGHEST_PROTOCOL)
                self.logger.info('written  %s', wikititle_by_id_filename)

            line = input_file.readline()
            if line is not None and line != '':
                data = json.loads(line)
                # pprint.pprint(data)
                if case_insensitive:
                    wikititle = data['wikiTitle'].lower()
                else:
                    wikititle = data['wikiTitle']

                wt_id = wikititle_mt[wikititle]
                wid = data['wid']
                wikititle_id_by_id[wid] = wt_id

            else:
                break

            count += 1

        self.logger.info('%d lines processed', count)
        wikititle_by_id_filename = fname_prefix + str(count) + '.pickle'
        self.logger.info('about to save to %s', wikititle_by_id_filename)
        with open(wikititle_by_id_filename, 'wb') as handle:
            pickle.dump(wikititle_id_by_id, handle, protocol=pickle.HIGHEST_PROTOCOL)
        self.logger.info('written  %s', wikititle_by_id_filename)
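
The checkpoints written above are plain pickled dicts mapping wid -> wikititle trie id, so collision counting can be done offline. A small sketch (the filename is a placeholder for whichever checkpoint was written last):

import pickle
from collections import Counter

with open('wikititle_id_by_id.case_insensitive.25000000.pickle', 'rb') as handle:  # placeholder path
    wikititle_id_by_id = pickle.load(handle)

# two different wids pointing at the same case-folded wikititle id is a collision
counts = Counter(wikititle_id_by_id.values())
collisions = {wt_id: n for wt_id, n in counts.items() if n > 1}
print(len(collisions), 'wikititles are shared by more than one wid')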
Example #4
    def __init__(self):
        # Set up logging
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(handler)
        self.logger.propagate = False
        self.logger.setLevel(logging.INFO)

        # instance variables
        self.wiki_ds = WikipediaDataset()
Example #5
    def get_links_totally_within(self, entity_id_list):
        from_list = []
        to_list = []
        value_list = []
        v0_vertice_set = set(entity_id_list)
        wds = WikipediaDataset()
        for entity_id in v0_vertice_set:
            links_to = wds.get_links_to(entity_id)
            for link_to in links_to:
                if link_to in v0_vertice_set:
                    to_list.append(entity_id)
                    from_list.append(link_to)
                    value_list.append(1)
        return from_list, to_list, value_list

    def __init__(self):
        # set up logging
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(handler)
        self.logger.propagate = False
        self.logger.setLevel(logging.INFO)
        # set up instance variables
        wds = WikipediaDataset()
        self.intermediate_path = FileLocations.get_temp_path()
        self.spotlight_util = SpotlightUtil()
Example #7
def unit_test_2():
    wds = WikipediaDataset()

    check_links(1563047,   7412236, wds) # steve_jobs
    check_links(16360692, 57564770, wds)
    check_links(2678997,  57564127, wds)
    check_links(37717778, 57563280, wds)
    check_links(43375967, 57563305, wds)
    check_links(46991680, 57563292, wds)
    check_links(51332113, 57564772, wds)
    check_links(52466986, 57563202, wds)
    check_links(52679129, 57563204, wds)
    check_links(57562759, 57565023, wds)
    check_links(57564483, 57564503, wds)
    check_links(57564520, 57564533, wds)
    check_links(57565377, 57565381, wds)
    check_links(57565437, 57565531, wds)
    check_links(603291,   57564623, wds)
    check_links(9422390,  57563903, wds)
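
check_links itself is not part of this example; a plausible sketch, assuming it simply verifies that the two curids are connected in the link graph via the WikipediaDataset accessors used elsewhere on this page:

def check_links(entity_id_a, entity_id_b, wds):
    # sketch only: report whether the two curids link to each other in the graph
    a_links_to_b = entity_id_b in set(wds.get_links_from(entity_id_a))
    b_links_to_a = entity_id_a in set(wds.get_links_from(entity_id_b))
    print(entity_id_a, entity_id_b, 'a->b:', a_links_to_b, 'b->a:', b_links_to_a)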
Example #8
def train_model():
    X, y, docid_array, entity_id_array = load_feature_matrix(
        feature_filename=INTERMEDIATE_PATH +
        'dexter_all_heavy_catted_8_7_2018.txt',
        feature_names=feature_names,
        entity_id_index=1,
        y_feature_index=2,
        first_feature_index=4,
        number_features_per_line=40,
        tmp_filename='/tmp/temp_conversion_file.txt')

    # train only on records we have a golden salience for
    fg = FilterGolden()
    logger.info('X Shape = %s', X.shape)
    logger.info('y Shape = %s', y.shape)

    dexter_dataset = DatasetDexter()
    wikipedia_dataset = WikipediaDataset()

    X2, y2, docid2, entityid2 = fg.get_only_golden_rows(
        X, y, docid_array, entity_id_array, dexter_dataset, wikipedia_dataset)

    logger.info('X2 Shape = %s', X2.shape)
    logger.info('y2 Shape = %s', y2.shape)

    wrapper = GBRTWrapper()
    gbrt = wrapper.train_model_no_split(X2, y2, n_estimators=40)
    logger.info('trained')
    # gbrt.save_model()

    # from https://shankarmsy.github.io/stories/gbrt-sklearn.html
    # One of the benefits of growing trees is that we can understand how important each of the features is
    print("Feature Importances")
    print(gbrt.feature_importances_)
    print()
    # Let's print the R-squared value for train/test. This explains how much of the variance in the data our model is
    # able to decipher.
    print("R-squared for Train: %.2f" % gbrt.score(X2, y2))
    # print ("R-squared for Test: %.2f" %gbrt.score(X_test, y_test) )
    # - See more at: https://shankarmsy.github.io/stories/gbrt-sklearn.html#sthash.JNZQbnph.dpuf
    return gbrt, X2, y2, docid2, entityid2
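
Once trained, the returned model can be used like a standard scikit-learn regressor (predict() is an assumption here; only score() and feature_importances_ are exercised above):

gbrt, X2, y2, docid2, entityid2 = train_model()
predictions = gbrt.predict(X2)  # assumes a sklearn-style predict()
print(predictions[:10])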
Example #9
    def go(self, filename, feature_names, filter_only_golden):
        X, y, docid_array, entity_id_array = load_feature_matrix(feature_filename=filename,
                                                                 feature_names=feature_names,
                                                                 entity_id_index=1,
                                                                 y_feature_index=2,
                                                                 first_feature_index=4,
                                                                 number_features_per_line=len(feature_names) + 4,
                                                                 tmp_filename='/tmp/temp_conversion_file.txt'
                                                                 )

        # train only on records we have a golden salience for
        self.logger.info('__________________________',)
        self.logger.info('File %s', filename)
        self.logger.info('X Shape = %s', X.shape)
        self.logger.info('y Shape = %s', y.shape)

        if filter_only_golden:
            dexterDataset = DatasetDexter()
            wikipediaDataset = WikipediaDataset()
            fg = sellibrary.filter_only_golden.FilterGolden()
            X, y, docid_array, entity_id_array = fg.get_only_golden_rows(X, y, docid_array, entity_id_array, dexterDataset, wikipediaDataset)
            self.logger.info('After filtering only golden rows:')
            self.logger.info('X Shape = %s', X.shape)
            self.logger.info('y Shape = %s', y.shape)

        self.logger.info('y[1:10] %s', y[1:10])
        self.logger.info('y > 0 %s', y[y > 0.0])

        y[y < 2.0] = 0
        y[y >= 2.0] = 1

        ig = self.information_gain_v2(X, y)
        self.logger.info('ig %s', ig)
        self.logger.info('ig shape %s', ig.shape)

        d = {}
        for i in range(len(feature_names)):
            d[feature_names[i]] = ig[i]

        self.sort_and_print(d)
        return d
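
information_gain_v2 is not shown on this page; for reference, below is a minimal, self-contained sketch of one common definition (entropy of the binarised labels minus the conditional entropy after splitting each feature at its median). It is an illustration only and may differ from the method actually used above.

import numpy as np

def entropy(y):
    # Shannon entropy of a label vector
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(X, y):
    base = entropy(y)
    ig = np.full(X.shape[1], base)
    for j in range(X.shape[1]):
        split = X[:, j] > np.median(X[:, j])  # crude binarisation at the feature median
        for mask in (split, ~split):
            if mask.any():
                ig[j] -= (mask.sum() / len(y)) * entropy(y[mask])
    return ig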
Example #10
import logging

from sellibrary.wiki.wikipedia_datasets import WikipediaDataset

# dense to sparse

# set up logging
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)

if __name__ == "__main__":
    ds = WikipediaDataset()
    curid = 399877

    in_degree = ds.get_entity_in_degree(curid)
    out_degree = ds.get_entity_out_degree(curid)
    degree = ds.get_entity_degree(curid)
    logger.info('degree %d', degree)
    logger.info('in degree %d', in_degree)
    logger.info('out degree %d', out_degree)

    assert(degree >= 54)
    assert(in_degree >= 9)
    assert(out_degree >= 45)
Example #11
import logging

from sellibrary.wiki.wikipedia_datasets import WikipediaDataset

# set up logging
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)

if __name__ == "__main__":
    ds = WikipediaDataset()
    # this requires extract_curid_by_wikititle_trie to have been run first
    ds.extract_graph_from_compressed()
Example #12
class SpotlightUtil:
    def __init__(self):
        # Set up logging
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(handler)
        self.logger.propagate = False
        self.logger.setLevel(logging.INFO)

        # instance variables
        self.wiki_ds = WikipediaDataset()

    def hit_spotlight_web_retun_text(self,doc_text, confidence):

        if doc_text.find(">") > -1 or doc_text.find(">") > -1:
            self.logger.info('contains html characters')
            pprint.pprint(doc_text)
            doc_text = doc_text.replace(">", " ")
            doc_text = doc_text.replace("<", " ")

        # for documentation see https://www.dbpedia-spotlight.org/api
        # confidence
        # http://api.dbpedia-spotlight.org/en/annotate?confidence=0.5&text= happy Hippo
        link = "http://api.dbpedia-spotlight.org/en/annotate?confidence="+str(confidence)+"&text="+doc_text

        for i in range(3): # max 3 attempts
            try:
                #headers = {'Accept': 'application/json'} # json would be easier to parse
                headers = {'Accept': 'text/html'}
                f = requests.get(link, headers=headers)
                if f.ok:
                    pprint.pprint(f.text)
                    i_start = f.text.find("<body>")
                    i_end = f.text.find("</body>")
                    if i_end > -1 and i_start > -1:
                        html = f.text[(i_start + 6):(i_end)]
                    else:
                        html = ""
                    self.logger.info("Spotlight returned "+str(len(html))+" bytes")
                    return html
                else:
                    self.logger.info('obtained error code %d %s',f.status_code, f.reason)
            except Exception:
                self.logger.info('error', exc_info=True)
        raise EnvironmentError('Could not complete web request')



    def my_pp(self,text):
        text = text.replace("\n", " ")
        text = text.replace("<a","\na<")
        text = text.replace("</a>","\n</a>\n")
        return text


    def get_wid_from_link_text(self, link):
        wid = -1
        t = "href=\"http://dbpedia.org/resource/"
        if link.find(t) >= 0:
            link_key_start = link.find(t) + len(t)
            link_key_end = link.find("\"", link_key_start)
            link_key = link[link_key_start:link_key_end]
            # print('link key:' + link_key)
            t = self.wiki_ds.get_wikititle_case_insensitive_marisa_trie()
            if link_key.lower() in t:
                values = t[link_key.lower()]
                wid = values[0][0]
                # print('wid:' + str(wid))
                # print('loc:',start_char,'-',end_char)
        return wid



    def post_process_html(self, html):
        #
        # There is a chance that this routine changes the text as it extracts the links from it
        # therefore the text is returned as well.
        #
        text = html.strip()

        if text.startswith("<div>"):
            text = text[5:]
        if text.endswith("</div>"):
            text = text[:-6]
        spots = []

        while text.find("<a",0) >= 0:

            start_char = text.find("<a")

            end_char = text.find("</a>")
            if end_char == -1:
                break
            next_link_start = text.find("<a")
            link = '-'
            while next_link_start < end_char and next_link_start != -1 and end_char != -1 and link != '':
                # we have interleaved links
                end_of_interleaved_link = text.find(">") + 1
                link = text[next_link_start:end_of_interleaved_link]

                #find piece to remove
                end_of_piece = end_of_interleaved_link
                if text.find("<a",end_of_interleaved_link) != -1 and \
                    text.find("<a",end_of_interleaved_link) < text.find("</a>",end_of_interleaved_link):
                    end_of_piece = text.find("<a",end_of_interleaved_link)

                text = text[0:next_link_start] + text[end_of_piece:]
                full_link = text[next_link_start:end_of_piece]

                if full_link.find("</a>") > -1 and full_link.find("</a>") < full_link.find(">"):
                    # this is a 'normal' link
                    anchor_text = full_link[:full_link.find("</a>")]
                else:
                    anchor_text = full_link[full_link.find(">")+1:]
                # extract link key
                wid = self.get_wid_from_link_text(link)
                s = Spot(wid, start_char, start_char + len(anchor_text), anchor_text)
                spots.append(s)
                print(wid, start_char, start_char + len(anchor_text), anchor_text)
                next_link_start = text.find("<a")
                end_char = text.find("</a>")

            # Remove the </a>
            end_char = text.find("</a>")
            if end_char > -1:
                text = text[0:end_char] + text[end_char+4:]
        return text, spots

    def hit_spotlight_return_spot_list(self, text, confidence):
        body_html = self.hit_spotlight_web_retun_text(text, confidence)
        processed_text, spot_list = self.post_process_html(body_html)
        return processed_text, spot_list
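
A short usage sketch for the class above (it needs the supporting Wikipedia datasets on disk and network access to the public DBpedia Spotlight endpoint; the 0.5 confidence value is just an example):

util = SpotlightUtil()
processed_text, spots = util.hit_spotlight_return_spot_list(
    'Berlin is the capital of Germany.', 0.5)
print(processed_text)
for spot in spots:
    print(spot)  # each Spot carries the wid, character offsets and anchor text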
Example #13
class WikipediaSpotter:
    def __init__(self):
        # set up logging
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(handler)
        self.logger.propagate = False
        self.logger.setLevel(logging.INFO)
        # set up instance variables
        self.wikiDS = WikipediaDataset()

    @staticmethod
    def find_nth(haystack, needle, n):
        start = haystack.find(needle)
        while start >= 0 and n > 1:
            start = haystack.find(needle, start + len(needle))
            n -= 1
        if start == -1:  # if not found return all
            start = len(haystack)
        return start

    def get_entity_candidates_using_wikititle(self, text):

        candidates = []
        wikititle_trie = self.wikiDS.get_wikititle_case_insensitive_marisa_trie(
        )

        # trie.items("fo")   give all this thing below this
        # trie["foo"] returns the key of this item
        # key in trie2 # returns true / false

        i = 0
        jump = 1
        while i < len(text):
            for num_words in range(1, 5):  # 4 word max
                nth = self.find_nth(text[i:], ' ', num_words)
                if num_words == 1:
                    jump = max(1, nth + 1)
                candidate = text[i:(i + nth)]
                candidate = candidate.lower()
                candidate = candidate.replace('.', ' ')
                candidate = candidate.strip()
                candidate = candidate.replace(' ', '_')
                if candidate in wikititle_trie:
                    t = wikititle_trie[candidate]
                    cid = t[0][0]
                    value_list = [i, (i + nth), candidate, cid]
                    candidates.append(value_list)
                    self.logger.info(value_list)
            i += jump

        return candidates

    def get_entity_candidates(self, text):

        candidates = []
        text_trie = self.wikiDS.get_anchor_text_case_insensitive_marisa_trie()

        # trie.items("fo")   give all this thing below this
        # trie["foo"] returns the key of this item
        # key in trie2 # returns true / false

        i = 0
        jump = 1
        while i < len(text):
            for num_words in range(1, 5):  # 4 word max
                nth = self.find_nth(text[i:], ' ', num_words)
                if num_words == 1:
                    jump = max(1, nth + 1)
                candidate = text[i:(i + nth)]
                candidate = candidate.lower()
                candidate = candidate.strip()
                if candidate in text_trie:
                    t = text_trie[candidate]
                    cid = t[0][0]
                    s = Spot(cid, i, (i + nth), text[i:(i + nth)])
                    candidates.append(s)

                if len(candidate) > 1 and candidate[
                        -1] == '.':  # special case remove trailing full stop
                    candidate = candidate[0:-1]
                    if candidate in text_trie:
                        t = text_trie[candidate]
                        cid = t[0][0]
                        s = Spot(cid, i, (i + nth - 1), text[i:(i + nth - 1)])
                        candidates.append(s)

                if i + jump >= len(text):
                    break
            i += jump

        return candidates
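
A minimal usage sketch for the spotter above (the sentence is arbitrary; candidate quality depends entirely on the underlying anchor-text trie):

spotter = WikipediaSpotter()
candidates = spotter.get_entity_candidates('Apple was founded by Steve Jobs in California.')
for spot in candidates:
    print(spot)  # Spot objects holding the candidate id, offsets and surface text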
Example #14
        self.logger.info('processing complete')

    # def train_and_save_model(self, filename):
    #     spotter = SpotlightCachingSpotter(False)
    #     afinn_filename = '../sellibrary/resources/AFINN-111.txt'
    #     sentiment_processor = SentimentProcessor()
    #     self.train_model_using_dexter_dataset(sentiment_processor, spotter, afinn_filename)
    #     sentiment_processor.save_model(filename)
    #     return sentiment_processor


if __name__ == "__main__":
    fg = FilterGolden()

    dd = DatasetDexter()
    wd = WikipediaDataset()

    dexter_json_doc_list = dd.get_dexter_dataset(
        FileLocations.get_dropbox_dexter_path(), 'saliency-dataset.json')
    golden_saliency_by_entid_by_docid = dd.get_golden_saliency_by_entid_by_docid(
        dexter_json_doc_list, wd)

    #check which are still valid

    wikititle_by_id = wd.get_wikititle_by_id()
    not_found_count = 0
    count = 0
    multiple_wid_count = 0

    for docid in golden_saliency_by_entid_by_docid.keys():
        for entity_id in golden_saliency_by_entid_by_docid[docid].keys():
Example #15
import logging

from sellibrary.wiki.wikipedia_datasets import WikipediaDataset

# set up logging
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)

if __name__ == "__main__":
    ds = WikipediaDataset()
    # this requires extract_curid_by_wikititle_trie to have been run first
    ds.convert_link_graph_to_csr_and_csc()
Example #16
from sellibrary.wiki.wikipedia_datasets import WikipediaDataset


def do_stuff(word, wds):
    id = wds.get_id_from_wiki_title(word)
    print(id)
    in_degree = wds.get_entity_in_degree(id)
    print('in_degree', in_degree)
    out_degree = wds.get_entity_out_degree(id)
    print('out_degree', out_degree)


if __name__ == "__main__":

    wds = WikipediaDataset()

    madrid = 41188263
    barcelona = 4443
    apple_inc = 8841385
    steve_jobs = 1563047
    steve_jobs = 7412236

    word = 'zorb'
    do_stuff(word, wds)

    word = 'united_states'
    do_stuff(word, wds)
Example #17
    def __init__(self, features_to_zero = []):

        # __ instance variables
        self.ds = WikipediaDataset()
        self.features_to_zero = features_to_zero
Example #18
class GraphUtils:
    # Set up logging
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
    logger = logging.getLogger(__name__)
    logger.addHandler(handler)
    logger.propagate = False
    logger.setLevel(logging.INFO)
    wds = WikipediaDataset()
    #_____ const
    REALLY_BIG_NUMBER = 100
    VERY_SMALL_NUMBER = 0.001

    def __init__(self):
        pass

    def relateness(self, entity_id_a, entity_id_b):
        # return the milne and witten relatedness value
        link_to_a = set(self.wds.get_links_to(entity_id_a))
        link_to_b = set(self.wds.get_links_to(entity_id_b))
        intersect = link_to_a.intersection(link_to_b)
        size_a = len(link_to_a)
        size_b = len(link_to_b)
        size_int = len(intersect)
        self.logger.debug(' %d, %d %d ', size_a, size_b, size_int)
        p1 = np.log2(max(size_a, size_b))
        p2 = np.log2(max(size_int, 1))
        p3 = np.log2(5)  # should be log2 of the total number of Wikipedia articles; as we only take the median downstream this may not matter
        p4 = np.log2(max(1, min(size_a, size_b)))
        if p3 == p4:
            self.logger.warning(
                'Error calculating relatedness, denominator is 0. Can only crudely estimate. p1=%f, p2=%f, p3=%f, p4=%f ',
                p1, p2, p3, p4)
            relatedness = (p1 - p2) / GraphUtils.VERY_SMALL_NUMBER
        else:
            relatedness = (p1 - p2) / (p3 - p4)
        return relatedness
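    # Note: p1..p4 above follow the Milne & Witten Wikipedia link-based measure,
    #     (log(max(|A|, |B|)) - log(|A ∩ B|)) / (log(|W|) - log(min(|A|, |B|)))
    # where A and B are the sets of articles linking to the two entities and |W| is
    # the total number of Wikipedia articles, so p3 above stands in for log2(|W|).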

    def calc_v1_matrix(self, entity_id_list):
        # find links that are within the set of nodes we are passed,
        # plus all links inbound to them and outbound from them
        from_list = []
        to_list = []
        value_list = []
        max_id = 0
        wds = WikipediaDataset()
        for from_entity_id in entity_id_list:
            link_to = wds.get_links_to(from_entity_id)
            for v in link_to:
                from_list.append(from_entity_id)
                to_list.append(v)
                value_list.append(1)
                if v > max_id:
                    max_id = v

            link_from = set(wds.get_links_from(from_entity_id))
            for v in link_from:
                from_list.append(v)
                to_list.append(from_entity_id)
                value_list.append(1)
                if v > max_id:
                    max_id = v
        # TODO The following line threw a ValueError (row index exceeds matrix dimensions) on docid 579 and docid 105
        try:
            mtx = sparse.coo_matrix((value_list, (from_list, to_list)),
                                    shape=(max_id + 1, max_id + 1))
        except ValueError as e:
            self.logger.warning(
                'An error occurred; returning None rather than a V1 matrix. %s',
                e)
            return None
        return mtx

    def get_links_totally_within(self, entity_id_list):
        from_list = []
        to_list = []
        value_list = []
        v0_vertice_set = set(entity_id_list)
        wds = WikipediaDataset()
        for entity_id in v0_vertice_set:
            links_to = wds.get_links_to(entity_id)
            for link_to in links_to:
                if link_to in v0_vertice_set:
                    to_list.append(entity_id)
                    from_list.append(link_to)
                    value_list.append(1)
        return from_list, to_list, value_list

    def calc_v0_matrix(self, entity_id_list):
        # only find links that are within the set of nodes we are passed
        from_list, to_list, value_list = self.get_links_totally_within(
            entity_id_list)
        l = []
        l.extend(from_list)
        l.extend(to_list)
        try:
            if len(l) > 0:
                max_id = max(l)  # l could be empty
            else:
                max_id = 1  # this occurred on docid = 214
            mtx = sparse.coo_matrix((value_list, (from_list, to_list)),
                                    shape=(max_id + 1, max_id + 1))
        except ValueError as e:
            self.logger.warning(
                'Could not calculate coo matrix. from_list = %s, to_list = %s, value_list = %s ',
                from_list, to_list, value_list)
            logging.exception('')
            mtx = None
        return mtx

    def get_diameter(self,
                     mtx,
                     entity_id_list,
                     print_names=False,
                     break_early=False,
                     optional_docId=-1):
        fairness_by_entity_id = {}
        for entity_id in entity_id_list:
            fairness_by_entity_id[entity_id] = 0

        self.logger.info(
            'docid = %s, Calculating distances for %d  entities. Approx duration %d sec =( %f min )',
            str(optional_docId), len(entity_id_list),
            len(entity_id_list) * 3,
            len(entity_id_list) * 3 / 60.0)

        max_dist = 0
        count = 0
        for entity_id_1 in entity_id_list:
            self.logger.info('%d/%d Calculating distances from entity_id %d ',
                             count, len(entity_id_list), entity_id_1)
            distances, predecessors = dijkstra(mtx,
                                               indices=entity_id_1,
                                               return_predecessors=True)
            for entity_id_2 in entity_id_list:
                if print_names:
                    # TODO load the name cache and print real names; fall back to the ids for now
                    e1_name = str(entity_id_1)
                    e2_name = str(entity_id_2)
                    print('from ', e1_name, '(', entity_id_1, ') to', e2_name,
                          '(', entity_id_2, ') distance',
                          distances[entity_id_2])
                d = distances[entity_id_2]
                if not np.isinf(d):
                    if d > max_dist:
                        max_dist = d

                    fairness_by_entity_id[
                        entity_id_1] = fairness_by_entity_id[entity_id_1] + d
                    fairness_by_entity_id[
                        entity_id_2] = fairness_by_entity_id[entity_id_2] + d
            count += 1
            if break_early and count > 3:
                self.logger.warning(
                    'Breaking early, so we will have a smaller graph. ')
                break

        print('diameter ', max_dist)
        return max_dist, fairness_by_entity_id

    def get_mean_median_in_degree(self,
                                  mtx,
                                  full_set_entity_ids,
                                  break_early=False):
        if break_early:
            self.logger.warning('Breaking early, returning made up results')
            return 1, 2
        if mtx is None:
            return 0, 0

        csc = mtx.tocsc()
        degree_sums = []
        for entity_id in full_set_entity_ids:
            s = csc.getcol(entity_id).sum()
            degree_sums.append(s)
        mean = np.mean(degree_sums)
        median = np.median(degree_sums)
        return mean, median

    def get_mean_median_out_degree(self,
                                   mtx,
                                   full_set_entity_ids,
                                   break_early=False):
        if break_early:
            self.logger.warning('Breaking early, returning made up results')
            return 1, 2

        if mtx is None:
            return 0, 0

        csr = mtx.tocsr()
        degree_sums = []
        for entity_id in full_set_entity_ids:
            s = csr.getrow(entity_id).sum()
            degree_sums.append(s)
        mean = np.mean(degree_sums)
        median = np.median(degree_sums)
        return mean, median

    def get_mean_median_degree(self,
                               mtx,
                               full_set_entity_ids,
                               break_early=False):
        degree_by_entity_id = {}
        if break_early:
            self.logger.warning('Breaking early, returning made up results')
            for entity_id in full_set_entity_ids:
                degree_by_entity_id[entity_id] = 1
            return 1, 2, degree_by_entity_id

        if mtx is None:
            for entity_id in full_set_entity_ids:
                degree_by_entity_id[entity_id] = 0
            return 0, 0, degree_by_entity_id

        csc = mtx.tocsc()
        for id in full_set_entity_ids:
            s = csc.getcol(id).sum()
            if id in degree_by_entity_id:
                degree_by_entity_id[id] = degree_by_entity_id[id] + s
            else:
                degree_by_entity_id[id] = s

        csr = mtx.tocsr()
        for id in full_set_entity_ids:
            s = csr.getrow(id).sum()
            if id in degree_by_entity_id:
                degree_by_entity_id[id] = degree_by_entity_id[id] + s
            else:
                degree_by_entity_id[id] = s

        x = list(degree_by_entity_id.values())
        mean = np.mean(x)
        median = np.median(x)
        return mean, median, degree_by_entity_id

    def get_degree_for_entity(self, mtx, entity_id):
        csc = mtx.tocsc()
        s1 = csc.getcol(entity_id).sum()
        csr = mtx.tocsr()
        s2 = csr.getrow(entity_id).sum()
        return s1 + s2

    def get_closeness_by_entity_id(self, fairness_by_entity_id):
        closeness_by_entity_id = {}
        for entity_id in fairness_by_entity_id.keys():
            if fairness_by_entity_id[entity_id] != 0.0:
                closeness_by_entity_id[
                    entity_id] = 1.0 / fairness_by_entity_id[entity_id]
            else:
                closeness_by_entity_id[
                    entity_id] = GraphUtils.REALLY_BIG_NUMBER

        return closeness_by_entity_id

    def get_dense_down_sampled_adj_graph(self, mtx):
        # map entity ids onto a compact, down-sampled id space
        entity_id_by_short_id = {}
        short_id_by_entity_id = {}

        t1 = 0
        t2 = 0
        if len(mtx.col) > 0:
            t1 = mtx.col.max()  # get max of this ndarray
        if len(mtx.row) > 0:
            t2 = mtx.row.max()  # get max of this ndarray
        max_id = max(t1, t2) + 1

        full_set_entity_ids = []
        full_set_entity_ids.extend(mtx.col)
        full_set_entity_ids.extend(mtx.row)
        count = 0
        for entity_id in full_set_entity_ids:
            entity_id_by_short_id[count] = entity_id
            short_id_by_entity_id[entity_id] = count
            count += 1

        # down sample the sparse matrix
        from_list = []
        to_list = []
        value_list = mtx.data
        for i in range(len(mtx.row)):
            from_list.append(short_id_by_entity_id[mtx.row[i]])
            to_list.append(short_id_by_entity_id[mtx.col[i]])

        max_id = 1
        if len(from_list) > 0:
            max_id = max(max_id, max(from_list)) + 1
        if len(to_list) > 0:
            max_id = max(max_id, max(to_list)) + 1

        mtx_small = sparse.coo_matrix((value_list, (from_list, to_list)),
                                      shape=(max_id, max_id))
        # build a networkx graph in the down-sampled id space
        dense = nx.from_scipy_sparse_matrix(mtx_small)

        return dense, entity_id_by_short_id, short_id_by_entity_id, from_list, to_list, mtx_small

    def calc_centrality(self, mtx, full_set_entity_ids):

        centrality_by_entity_id = {}
        if mtx is None:
            for entity_id in full_set_entity_ids:
                centrality_by_entity_id[entity_id] = 0.0
            return centrality_by_entity_id

        # map the sparse matrix into a compact, down-sampled id space
        dense, entity_id_by_short_id, short_id_by_entity_id, from_list, to_list, mtx_small = self.get_dense_down_sampled_adj_graph(
            mtx)

        # calc centrality
        try:
            centrality = nx.eigenvector_centrality_numpy(dense)
            # convert centrality index back to the original space
            for k in centrality.keys():
                centrality_by_entity_id[
                    entity_id_by_short_id[k]] = centrality[k]
            self.logger.info(centrality_by_entity_id)

        except (ValueError, TypeError, KeyError, nx.NetworkXException):
            # the failure modes are handled identically: default every entity to 1
            self.logger.warning(
                'Could not calculate centrality. defaulting to 1')
            for entity_id in full_set_entity_ids:
                centrality_by_entity_id[entity_id] = 1
            self.logger.warning('mtx_small %s:', mtx_small)
            self.logger.warning("Nodes in G: %s ", dense.nodes(data=True))
            self.logger.warning("Edges in G: %s ", dense.edges(data=True))
            logging.exception('')

        return centrality_by_entity_id

    def calc_all_features(self, mtx, break_early=False, optional_docId=-1):
        full_set_entity_ids = self.get_unique_set_of_entity_ids(mtx)
        if break_early:
            self.logger.warning("Limiting the number of heavy entities to 5")
            l = list(full_set_entity_ids)
            full_set_entity_ids = set(l[0:min(5, len(l))])

        self.logger.info(
            'Calculating diameter & fairness on matrix with %d vertices',
            len(full_set_entity_ids))
        diameter, fairness_by_entity_id = self.get_diameter(
            mtx,
            full_set_entity_ids,
            break_early=break_early,
            optional_docId=optional_docId)
        feature_1_graph_size = len(full_set_entity_ids)
        self.logger.info('graph size: %d', feature_1_graph_size)
        feature_2_graph_diameter = diameter
        self.logger.info('diameter: %d', diameter)
        mean, median = self.get_mean_median_in_degree(mtx, full_set_entity_ids,
                                                      break_early)
        if median == 0.0:
            self.logger.warning('mean: %f median: %f', mean, median)
            feature_4_in_degree_mean_median = 0  # this can happen from small sets of input entities with no links between them
        else:
            feature_4_in_degree_mean_median = mean / median
        self.logger.info('in degree mean/median: %f',
                         feature_4_in_degree_mean_median)
        mean, median = self.get_mean_median_out_degree(mtx,
                                                       full_set_entity_ids,
                                                       break_early)
        if median == 0.0:
            self.logger.warning('mean: %f median: %f', mean, median)
            feature_5_out_degree_mean_median = 0  # valid for this to be 0
        else:
            feature_5_out_degree_mean_median = mean / median
        self.logger.info('out degree mean/median: %f',
                         feature_5_out_degree_mean_median)
        self.logger.info('calculating mean and median degrees.')
        mean, median, degree_by_entity_id = self.get_mean_median_degree(
            mtx, full_set_entity_ids, break_early=break_early)
        feature_3_node_degree_by_entity_id = degree_by_entity_id
        self.logger.info('node_degree_by_entity_id: %s',
                         feature_3_node_degree_by_entity_id)
        if median == 0.0:
            self.logger.warning('mean: %f median: %f', mean, median)
            feature_6_degree_mean_median = 0  # valid for this to be 0
        else:
            feature_6_degree_mean_median = mean / median
        self.logger.info('degree mean/median: %f',
                         feature_6_degree_mean_median)
        feature_7_fairness_by_entity_id = fairness_by_entity_id
        self.logger.info('fairness_by_entity_id: %s', fairness_by_entity_id)
        feature_8_closeness_by_entity_id = self.get_closeness_by_entity_id(
            fairness_by_entity_id)
        self.logger.info('closeness by entity id: %s',
                         feature_8_closeness_by_entity_id)
        feature_9_centrality_by_entity_id = self.calc_centrality(
            mtx, full_set_entity_ids)
        self.logger.info('centrality by entity id: %s',
                         feature_9_centrality_by_entity_id)
        return feature_1_graph_size, feature_2_graph_diameter, feature_3_node_degree_by_entity_id, feature_4_in_degree_mean_median, \
               feature_5_out_degree_mean_median, feature_6_degree_mean_median, feature_7_fairness_by_entity_id, feature_8_closeness_by_entity_id, feature_9_centrality_by_entity_id

    def filter_low_milne_and_witten_relatedness(self, mtx):
        if mtx is None:
            return None

        self.logger.info('Calculating milne and witten relatedness')
        col_values = []
        row_values = []
        data_values = []
        max_id = 0
        for i in range(len(mtx.data)):
            from_entity_id = mtx.row[i]
            to_entity_id = mtx.col[i]
            relatedness = self.relateness(from_entity_id, to_entity_id)
            if relatedness > 0.0:
                col_values.append(mtx.col[i])
                row_values.append(mtx.row[i])
                data_values.append(mtx.data[i])
                if mtx.col[i] > max_id:
                    max_id = mtx.col[i]
                if mtx.row[i] > max_id:
                    max_id = mtx.row[i]

        mtx = sparse.coo_matrix((data_values, (row_values, col_values)),
                                shape=(max_id + 1, max_id + 1))
        return mtx

    def get_unique_set_of_entity_ids(self, mtx):
        if mtx is None:
            return set()
        full_set = set(mtx.col)
        full_set.update(mtx.row)
        return full_set
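
A sketch of how the pieces of GraphUtils fit together end to end (the entity ids are placeholders; break_early=True keeps the expensive Dijkstra loop short):

gu = GraphUtils()
entity_ids = [1563047, 7412236, 4443]          # placeholder Wikipedia curids
mtx = gu.calc_v1_matrix(entity_ids)            # expanded graph: the ids plus their in/out links
mtx = gu.filter_low_milne_and_witten_relatedness(mtx)
features = gu.calc_all_features(mtx, break_early=True, optional_docId=1)
graph_size, diameter = features[0], features[1]
print('graph size', graph_size, 'diameter', diameter)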
Example #19
import logging
import pickle

from sellibrary.wiki.wikipedia_datasets import WikipediaDataset
from sel.file_locations import FileLocations

# dense to sparse

# set up logging
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)

if __name__ == "__main__":
    ds = WikipediaDataset()
    wikititle_marisa_trie = ds.get_wikititle_case_insensitive_marisa_trie()
    logger.info('Creating dictionary')
    wikititle_by_id = {}
    for k in wikititle_marisa_trie.keys():
        wid = wikititle_marisa_trie.get(k)[0][0]
        wikititle_by_id[wid] = k

    logger.info('complete')

    output_filename = FileLocations.get_dropbox_wikipedia_path(
    ) + 'wikititle_by_id.pickle'
    logger.info('About to write %s', output_filename)
    with open(output_filename, 'wb') as handle:
        pickle.dump(wikititle_by_id, handle, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info('file written = %s', output_filename)
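
Reading the pickle back is the mirror image (the path is the same Dropbox location used above):

import pickle
from sel.file_locations import FileLocations

input_filename = FileLocations.get_dropbox_wikipedia_path() + 'wikititle_by_id.pickle'
with open(input_filename, 'rb') as handle:
    wikititle_by_id = pickle.load(handle)
print(len(wikititle_by_id), 'titles loaded')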
Example #20
if __name__ == "__main__":

    filename = FileLocations.get_dropbox_intermediate_path() + 'sel.pickle'
    build_model = False

#    smb = SelModelBuilder()

    # if build_model:
    #     sentiment_processor = smb.train_and_save_model(filename)
    # else:
    #     sentiment_processor = SentimentProcessor()
    #     sentiment_processor.load_model(filename)

    dd = DatasetDexter()
    wikipediaDataset = WikipediaDataset()
    document_list = dd.get_dexter_dataset(path=FileLocations.get_dropbox_dexter_path())
    spotter = GoldenSpotter(document_list, wikipediaDataset)

    golden_saliency_by_entid_by_docid = dd.get_golden_saliency_by_entid_by_docid(document_list, wikipediaDataset)

    wikititle_by_id = wikipediaDataset.get_wikititle_by_id()

    docid = 1
    for entity_id in golden_saliency_by_entid_by_docid[docid]:
        logger.info('___________________')
        logger.info(entity_id)
        entity_id2 = wikipediaDataset.get_wikititle_id_from_id(entity_id)

        if entity_id in  wikititle_by_id:
            logger.info(wikititle_by_id[entity_id])
Example #21
    output_filename = dropbox_intermediate_path + 'wp_joined.txt'  #'joined_sel_sent_and_tf.txt'

    # Load File A
    X1, y1, docid_array1, entity_id_array1 = load_feature_matrix(
        feature_filename=filename_A,
        feature_names=file_A_feature_names,
        entity_id_index=1,
        y_feature_index=2,
        first_feature_index=4,
        number_features_per_line=len(file_A_feature_names) + 4,
        tmp_filename='/tmp/temp_conversion_file.txt')

    print(y1.shape)
    dexter_dataset = DatasetDexter()
    wikipedia_dataset = WikipediaDataset()
    # fg = FilterGolden()
    # X1, y1, docid_array1, entity_id_array1 = fg.get_only_golden_rows(X1, y1, docid_array1, entity_id_array1, dexter_dataset,
    #                                                     wikipedia_dataset)

    document_list = dexter_dataset.get_dexter_dataset(
        path=FileLocations.get_dropbox_dexter_path())
    golden_saliency_by_entid_by_docid = dexter_dataset.get_golden_saliency_by_entid_by_docid(
        document_list, wikipedia_dataset)

    print(y1.shape)

    # Load File B
    X2, y2, docid_array2, entity_id_array2 = load_feature_matrix(
        feature_filename=filename_B,
        feature_names=file_B_feature_names,
    print('not_salient_list:' + str(not_salient_list))
    print('salient_list:' + str(salient_list))


if __name__ == "__main__":

    filename = FileLocations.get_dropbox_intermediate_path() + 'sel.pickle'
    build_model = False

    #    smb = SelModelBuilder()

    # if build_model:
    #     sentiment_processor = smb.train_and_save_model(filename)
    # else:
    #     sentiment_processor = SentimentProcessor()
    #     sentiment_processor.load_model(filename)

    dd = DatasetDexter()
    wikipediaDataset = WikipediaDataset()
    document_list = dd.get_dexter_dataset(
        path=FileLocations.get_dropbox_dexter_path())
    spotter = GoldenSpotter(document_list, wikipediaDataset)

    golden_saliency_by_entid_by_docid = dd.get_golden_saliency_by_entid_by_docid(
        document_list, wikipediaDataset)

    wikititle_by_id = wikipediaDataset.get_wikititle_by_id()

    show_doc_info(2)
Example #23
class SELLightFeatureExtractor:

    # set up logging
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
    logger = logging.getLogger(__name__)
    logger.addHandler(handler)
    logger.propagate = False
    logger.setLevel(logging.INFO)

    def __init__(self, features_to_zero = []):

        # __ instance variables
        self.ds = WikipediaDataset()
        self.features_to_zero = features_to_zero

    def get_entity_set(self, entity_list):
        entity_set = set()
        name_by_entity_id = {}
        for e in entity_list:
            entity_set.add(e.entity_id)
            name = e.text
            name_by_entity_id[e.entity_id] = name
        return entity_set, name_by_entity_id

    # 1 Position ___________________________________________________________
    def calc_pos_for_entity(self, text, entity_list, entity_id):
        count = 0
        positions = []
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                i = (e2.start_char / len(text))
                positions.append(i)
                count += 1
        return np.min(positions), np.max(positions), np.mean(positions), np.std(positions)

    def calc_positions(self, text, entity_list, entity_id_set):
        pos = {}
        for entity_id in entity_id_set:
            positions = self.calc_pos_for_entity(text, entity_list, entity_id)
            pos[entity_id] = positions
        return pos

    # 2 First field position ______________________________________________________

    @staticmethod
    def find_nth(haystack, needle, n):
        start = haystack.find(needle)
        while start >= 0 and n > 1:
            start = haystack.find(needle, start + len(needle))
            n -= 1
        return start

    @staticmethod
    def get_normalised_first_pos(entity_list, entity_id, lower_bound, upper_bound):
        first_location = None
        normed = None
        anchor = None
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                anchor = e2.text
                if lower_bound <= e2.start_char < upper_bound:
                    if first_location is None:
                        first_location = e2.start_char
                        normed = (first_location - lower_bound) / (upper_bound - lower_bound)
                    else:
                        if e2.start_char < first_location:
                            first_location = e2.start_char
                            normed = (first_location - lower_bound) / (upper_bound - lower_bound)
        return normed, anchor

    # return the bounds of the first 3 sentences, mid sentences, and last 3 sentences
    def get_body_positions(self, body):
        end_of_third_sentence = self.find_nth(body, '.', 3) + 1
        if end_of_third_sentence == 0:
            end_of_third_sentence = len(body)
        num_sentences = body.count(".") + 1
        start_of_last_three_sentences = self.find_nth(body, '.', num_sentences - 3) + 1
        if start_of_last_three_sentences == 0:
            start_of_last_three_sentences = len(body)
        if start_of_last_three_sentences < end_of_third_sentence:
            start_of_last_three_sentences = end_of_third_sentence

        return [
            0, end_of_third_sentence,  # first 3 sentences
            end_of_third_sentence + 1, start_of_last_three_sentences,  # middle
            start_of_last_three_sentences + 1, len(body)  # last three sentences
        ]

    def calc_first_field_positions_for_entity(self, body, title, entity_list, entity_id, title_entity_list,
                                              title_entity_id):

        first_section_start, first_section_end, \
        mid_section_start, mid_section_end, \
        last_section_start, last_section_end = self.get_body_positions(body)

        norm_first = self.get_normalised_first_pos(entity_list, entity_id, first_section_start, first_section_end)
        norm_middle = self.get_normalised_first_pos(entity_list, entity_id, mid_section_start, mid_section_end)
        norm_end = self.get_normalised_first_pos(entity_list, entity_id, last_section_start, last_section_end)
        title_first_location = self.get_normalised_first_pos(title_entity_list, title_entity_id, 0, len(title))

        return norm_first[0], norm_middle[0], norm_end[0], title_first_location[0]

    def calc_first_field_positions(self, body, title, entity_list, entity_id_set, title_entity_list):
        first_field_positions_by_ent_id = {}
        for entity_id in entity_id_set:
            first_field_positions_by_ent_id[entity_id] = self.calc_first_field_positions_for_entity(body, title,
                                                                                                    entity_list,
                                                                                                    entity_id,
                                                                                                    title_entity_list,
                                                                                                    entity_id)
        return first_field_positions_by_ent_id

    # 3 Sentence Position ______________________________________________________

    @staticmethod
    def get_average_normalised_pos(entity_list, entity_id, lower_bound, upper_bound):
        normed_positions = []
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                if lower_bound <= e2.start_char < upper_bound:
                    normed = (e2.start_char - lower_bound) / (upper_bound - lower_bound)
                    normed_positions.append(normed)
        return normed_positions

    def calc_sentence_positions_for_entity(self, body, entity_list, entity_id):
        num_sentences = body.count(".") + 1
        start_index = 0
        normed_positions = []
        for sentence_num in range(1, num_sentences):
            end_index = self.find_nth(body, '.', sentence_num)
            normed_positions.extend(self.get_average_normalised_pos(entity_list, entity_id, start_index, end_index))
            start_index = end_index + 1  # save a loop by copying

        self.logger.debug('normed positions = %s ', normed_positions)
        return np.mean(normed_positions)

    def calc_sentence_positions(self, body, entity_list, entity_id_set):
        sentence_positions_by_ent_id = {}
        for entity_id in entity_id_set:
            sentence_positions_by_ent_id[entity_id] = self.calc_sentence_positions_for_entity(body, entity_list,
                                                                                              entity_id)
        return sentence_positions_by_ent_id

    # 4 field frequency  ___________________________________________________________


    def calc_field_frequency(self, body, entity_list, title_entity_list):
        first_section_start, first_section_end, \
        mid_section_start, mid_section_end, \
        last_section_start, last_section_end = self.get_body_positions(body)

        field_frequency_by_ent_id = {}

        for e2 in entity_list:
            if e2.entity_id not in field_frequency_by_ent_id:
                field_frequency_by_ent_id[e2.entity_id] = [0, 0, 0, 0]

            if first_section_start <= e2.start_char <= first_section_end:
                field_frequency_by_ent_id[e2.entity_id][0] = field_frequency_by_ent_id[e2.entity_id][0] + 1

            if mid_section_start <= e2.start_char <= mid_section_end:
                field_frequency_by_ent_id[e2.entity_id][1] = field_frequency_by_ent_id[e2.entity_id][1] + 1

            if last_section_start <= e2.start_char <= last_section_end:
                field_frequency_by_ent_id[e2.entity_id][2] = field_frequency_by_ent_id[e2.entity_id][2] + 1

        for e2 in title_entity_list:
            if e2.entity_id not in field_frequency_by_ent_id:
                field_frequency_by_ent_id[e2.entity_id] = [0, 0, 0, 0]
            field_frequency_by_ent_id[e2.entity_id][3] = field_frequency_by_ent_id[e2.entity_id][3] + 1

        return field_frequency_by_ent_id

    # 5 capitalization ___________________________________________________________

    def calc_capitalization_for_entity(self, text, entity_list, entity_id):
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                message = e2.text.strip()
                u = sum(1 for c in message if c.isupper())
                c = len(message)
                if (c == u):
                    return True
        return False

    # return True iff at least one mention of cj is capitalized
    def calc_capitalization(self, text, entity_list, entity_id_set):
        capitalization_by_ent_id = {}
        for entity_id in entity_id_set:
            capitalization_by_ent_id[entity_id] = self.calc_capitalization_for_entity(text, entity_list, entity_id)
        return capitalization_by_ent_id

    # 6 Uppercase ratio ___________________________________________________________

    def calc_uppercase_ratio_for_entity(self, text, entity_list, entity_id):
        count = 0
        upper_count = 0
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                message = e2.text
                u = sum(1 for c in message if c.isupper())
                c = len(message)
                count += c
                upper_count += u
        if count == 0:
            return 0
        else:
            return upper_count / count

    # maximum fraction of uppercase letters among the spots referring to cj
    def calc_uppercase_ratio(self, text, entity_list, entity_id_set):
        uppercase_ratio_by_ent_id = {}
        for entity_id in entity_id_set:
            uppercase_ratio_by_ent_id[entity_id] = self.calc_uppercase_ratio_for_entity(text, entity_list, entity_id)
        return uppercase_ratio_by_ent_id

    # 7 highlighting  ___________________________________________________________
    #
    # Not yet implemented, as we do not have this information

    # 8.1 Average Lengths in words ___________________________________________________________

    @staticmethod
    def calc_average_term_length_in_words_for_entity(text, entity_list, entity_id):
        length_list = []
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                message = e2.text.strip()
                space_count = sum(1 for c in message if c == ' ')
                word_count = space_count + 1
                length_list.append(word_count)

        if not length_list:
            return 0  # no mentions of this entity; avoid taking the mean of an empty list
        return np.mean(length_list)

    # length in words
    def calc_average_term_length_in_words(self, text, entity_list, entity_id_set):
        average_term_length_by_ent_id = {}
        for entity_id in entity_id_set:
            average_term_length_by_ent_id[entity_id] = self.calc_average_term_length_in_words_for_entity(text,
                                                                                                         entity_list,
                                                                                                         entity_id)
        return average_term_length_by_ent_id

    # 8.2 Average Lengths in characters___________________________________________________________

    def calc_average_term_length_in_characters_for_entity(self, entity_list, entity_id):
        length_list = []
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                message = e2.text.strip()
                char_length = len(message)
                length_list.append(char_length)
        if not length_list:
            return 0  # no mentions of this entity; avoid taking the mean of an empty list
        return np.mean(length_list)

    # length in characters
    def calc_average_term_length_in_characters(self, entity_list, entity_id_set):
        average_term_length_by_ent_id = {}
        for entity_id in entity_id_set:
            average_term_length_by_ent_id[entity_id] = self.calc_average_term_length_in_characters_for_entity(
                entity_list, entity_id)
        return average_term_length_by_ent_id

    # 11 Is In Title ___________________________________________________________

    def calc_is_in_title(self, entity_list, title_entity_list):
        is_in_title_by_ent_id = {}
        for e2 in entity_list:  # ensure we have a full dictionary
            is_in_title_by_ent_id[e2.entity_id] = False
        for e2 in title_entity_list:
            is_in_title_by_ent_id[e2.entity_id] = True
        return is_in_title_by_ent_id

    # 12 link probabilities ___________________________________________________________

    # The link probability for a spot $s_i \in S_D$ is defined as the number of occurrences of $s_i$
    # being a link to an entity in the KB, divided by its
    # total number of occurrences in the KB.
    # i.e. how often this anchor text is actually used as a link (see the sketch below)
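
    # Hedged sketch (not part of the original code): one possible way to compute the
    # link probability defined above. The two dictionaries are hypothetical lookups:
    # 'link_occurrence_count' maps a spot text to how often it appears as link anchor
    # text in the KB, 'total_occurrence_count' to how often it appears at all.
    def calc_link_probability_sketch(self, spot_text, link_occurrence_count, total_occurrence_count):
        occurrences_total = total_occurrence_count.get(spot_text, 0)
        if occurrences_total == 0:
            return 0.0  # spot never seen in the KB; avoid division by zero
        return link_occurrence_count.get(spot_text, 0) / occurrences_total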

    # 13 is person - requires another download  _____________________________________________________

    # from https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/yago-naga/yago/downloads/
    # yagoSimpleTypes
    # SIMPLETAX : A simplified rdf:type system. This theme contains all instances, and links them with rdf:type facts to the leaf level of WordNet (use with yagoSimpleTaxonomy)
    # TSV version
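
    # Hedged sketch (not part of the original code): one way the yagoSimpleTypes TSV
    # mentioned above could feed an "is person" feature. The tab-separated column
    # layout (fact id, subject, predicate, object) is an assumption about that
    # download, and 'person_type_names' is a hypothetical set of WordNet leaf types
    # already known (e.g. via yagoSimpleTaxonomy) to be subtypes of person.
    def build_person_wikititle_set_sketch(self, yago_simple_types_tsv_path, person_type_names):
        person_wikititles = set()
        with open(yago_simple_types_tsv_path, 'rt', encoding='utf-8') as tsv_file:
            for line in tsv_file:
                parts = line.rstrip('\n').split('\t')
                if len(parts) < 4:
                    continue  # skip malformed or comment lines
                subject, rdf_type = parts[1], parts[3]
                if rdf_type.strip('<>') in person_type_names:
                    person_wikititles.add(subject.strip('<>'))  # wikititle-style name with underscores
        return person_wikititles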


    # 14 Entity frequency ___________________________________________________________

    def calc_entity_frequency_for_entity(self, text, entity_id, name_by_entity_id):
        count = text.count(name_by_entity_id[entity_id])
        # print(entity_id,name_by_entity_id[entity_id],count)
        return count

    def calc_entity_frequency(self, text, entity_id_set, name_by_entity_id):
        entity_frequency_by_ent_id = {}
        for entity_id in entity_id_set:
            entity_frequency_by_ent_id[entity_id] = self.calc_entity_frequency_for_entity(text, entity_id,
                                                                                          name_by_entity_id)
        return entity_frequency_by_ent_id

    # 15 distinct mentions  ___________________________________________________________

    # how is this different from the number of mentions? (one interpretation is sketched below)
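
    # Hedged sketch (not part of the original code): one reading of "distinct mentions"
    # is the number of distinct surface forms (spot texts) referring to the entity,
    # as opposed to the raw mention count; e.g. "Obama" and "Barack Obama" would count
    # as two distinct mentions of the same entity.
    def calc_distinct_mentions_sketch(self, entity_list, entity_id):
        surface_forms = set()
        for e2 in entity_list:
            if e2.entity_id == entity_id:
                surface_forms.add(e2.text.strip().lower())
        return len(surface_forms)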

    # 16 no ambiguity ___________________________________________________________


    # 17 ambiguity ___________________________________________________________

    # calculated as: 1 - (1 / number of candidate entities for the spot); see the sketch below
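
    # Hedged sketch (not part of the original code) of the formula above, assuming the
    # number of candidate entities returned for the spot is already known.
    def calc_ambiguity_sketch(self, num_candidate_entities):
        if num_candidate_entities <= 0:
            return 0.0  # no candidates; treat as unambiguous
        return 1.0 - (1.0 / num_candidate_entities)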

    # 18 commonness___________________________________________________________

    # commonness: when a spot points to many candidate entities, the ratio of the number of
    # links that point to entity A to the number that point to any entity (see the sketch below)
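
    # Hedged sketch (not part of the original code): commonness as the fraction of KB links
    # with this anchor text that point to the given candidate entity. 'link_target_counts'
    # is a hypothetical dict mapping candidate entity id -> number of links from this spot
    # text to that entity.
    def calc_commonness_sketch(self, link_target_counts, entity_id):
        total_links = sum(link_target_counts.values())
        if total_links == 0:
            return 0.0
        return link_target_counts.get(entity_id, 0) / total_links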


    # 19 max commonness x max link probability ___________________________________

    # 20 entity degree ___________________________________
    # In-degree, out-degree and (undirected) degree of cj in the Wikipedia citation graph




    def calc_degrees(self, entity_id_set):
        degrees_by_ent_id = {}
        for entity_id in entity_id_set:
            in_degree = self.ds.get_entity_in_degree(entity_id)
            out_degree = self.ds.get_entity_out_degree(entity_id)
            degree = in_degree + out_degree  # self.ds.get_entity_degree(entity_id)
            result = [in_degree, out_degree, degree]
            degrees_by_ent_id[entity_id] = result
            self.logger.info('entity_id %d in out and total degrees: %s', entity_id, result)
        return degrees_by_ent_id

    # 21 entity degree x max commonness ___________________________________


    # 22 document_length___________________________________________________________

    def calc_document_length(self, body, entity_id_set):
        document_length_by_ent_id = {}
        for entity_id in entity_id_set:
            document_length_by_ent_id[entity_id] = len(body)
        return document_length_by_ent_id

    # Combine all light features, on a per entity basis ___________________________________________________________

    def calc_light_features(self, body, title, entity_list, entity_id_set, name_by_entity_id, title_entity_list,
                            very_light=False):

        self.logger.info('calc_light_features 1')
        position_features_by_ent_id = self.calc_positions(body, entity_list, entity_id_set)  # 1
        self.logger.info('calc_light_features 2')
        field_positions_by_ent_id = self.calc_first_field_positions(body, title, entity_list, entity_id_set,
                                                                    title_entity_list)  # 2
        self.logger.info('calc_light_features 3')
        sentence_positions_by_ent_id = self.calc_sentence_positions(body, entity_list, entity_id_set)  # 3
        self.logger.info('calc_light_features 4')
        frequency_by_ent_id = self.calc_field_frequency(body, entity_list, title_entity_list)  # 4
        self.logger.info('calc_light_features 5')
        capitalization_by_ent_id = self.calc_capitalization(body, entity_list, entity_id_set)  # 5
        self.logger.info('calc_light_features 6')
        uppercase_ratio_by_ent_id = self.calc_uppercase_ratio(body, entity_list, entity_id_set)  # 6
        self.logger.info('calc_light_features 8.1')

        term_length_w_by_ent_id = self.calc_average_term_length_in_words(body, entity_list, entity_id_set)  # 8.1
        self.logger.info('calc_light_features 8.2')
        term_length_c_by_ent_id = self.calc_average_term_length_in_characters(entity_list, entity_id_set)  # 8.2
        self.logger.info('calc_light_features 11')

        title_by_ent_id = self.calc_is_in_title(entity_list, title_entity_list)  # 11
        self.logger.info('calc_light_features 14')
        entity_frequency_by_ent_id = self.calc_entity_frequency(body, entity_id_set, name_by_entity_id)  # 14
        self.logger.info('calc_light_features 20')

        if very_light:
            degrees_by_ent_id = {entity_id: [0, 0, 0] for entity_id in entity_id_set}
        else:
            degrees_by_ent_id = self.calc_degrees(entity_id_set)  # 20
        self.logger.info('calc_light_features 22')
        doc_length_by_ent_id = self.calc_document_length(body, entity_id_set)  # 22

        self.logger.info('Reshaping results for document')

        results = {}
        for entity_id in entity_id_set:
            feature_list = []
            feature_list.extend(position_features_by_ent_id[entity_id])  # 1: 4 position features
            feature_list.extend(field_positions_by_ent_id[entity_id])  # 2
            feature_list.append(sentence_positions_by_ent_id[entity_id])  # 3
            feature_list.extend(frequency_by_ent_id[entity_id])  # 4
            feature_list.append(capitalization_by_ent_id[entity_id])  # 5
            feature_list.append(uppercase_ratio_by_ent_id[entity_id])  # 6 : 1 uppercase feature

            feature_list.append(term_length_w_by_ent_id[entity_id])  # 8.1 :
            feature_list.append(term_length_c_by_ent_id[entity_id])  # 8.2 :

            feature_list.append(title_by_ent_id[entity_id])  # 11 :

            feature_list.append(entity_frequency_by_ent_id[entity_id])  # 14 : 1 entity frequency feature

            feature_list.extend(degrees_by_ent_id[entity_id])  # 20 :
            feature_list.append(doc_length_by_ent_id[entity_id])  # 22 :

            # zero some features in order to do sensitivity checking
            for index in self.features_to_zero:
                if 0 <= index < len(feature_list):
                    feature_list[index] = 0

            results[entity_id] = feature_list
        return results

    # ___________________________________________________________


    def get_entity_saliency_list(self, body, title, spotter, very_light=False, spotter_confidence=0.5):
        entity_list = spotter.get_entity_candidates(body, spotter_confidence)
        entity_id_set, name_by_entity_id = self.get_entity_set(entity_list)
        title_entity_list = spotter.get_entity_candidates(title, spotter_confidence)
        features_by_ent_id = self.calc_light_features(body, title, entity_list, entity_id_set, name_by_entity_id,
                                                      title_entity_list, very_light)
        title_entity_id_set, title_name_by_entity_id = self.get_entity_set(title_entity_list)
        return entity_list, entity_id_set, features_by_ent_id, name_by_entity_id, title_entity_list, title_entity_id_set

    def get_feature_list_by_ent(self, body, title, spotter, very_light=False, spotter_confidence=0.5):
        entity_list, entity_id_set, features_by_ent_id, name_by_entity_id, title_entity_list, title_entity_id_set = \
            self.get_entity_saliency_list(body, title, spotter, very_light, spotter_confidence=spotter_confidence)
        return features_by_ent_id, name_by_entity_id
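
# Hedged usage sketch (not part of the original code): how the light features might be
# obtained for a single document, assuming 'feature_extractor' is an instance of the
# feature class above and 'spotter' exposes get_entity_candidates(text, confidence).
def example_light_feature_usage(feature_extractor, spotter, body, title):
    features_by_ent_id, name_by_entity_id = feature_extractor.get_feature_list_by_ent(
        body, title, spotter, very_light=True, spotter_confidence=0.5)
    for entity_id, feature_list in features_by_ent_id.items():
        print(entity_id, name_by_entity_id[entity_id], feature_list)
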
class ExtractAnchorText:
    def __init__(self):
        # Set up logging
        self.wikiDS = WikipediaDataset()
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'))
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(handler)
        self.logger.propagate = False
        self.logger.setLevel(logging.INFO)

    def get_intermediate_path(self):
        if sys.platform == 'win32':
            path = 'C:\\temp\\'
        else:
            path = '/Users/dsluis/Data/intermediate/'
        # self.logger.info(path)
        return path

    # takes about 3.5 hours to run
    # link_type is 'BODY' or 'LINK'
    def create_anchor_text_marisa_trie(self,
                                       case_insensitive=True,
                                       link_type='BODY'):
        wikititle_trie = self.wikiDS.get_wikititle_case_insensitive_marisa_trie()

        input_file = gzip.open("E:\\tmp\\" + 'wikipedia-dump.json.gz',
                               'rt',
                               encoding='utf-8')

        unique_text = []
        unique_wids = []
        unique_text_set = set()

        fname_prefix = self.get_intermediate_path() + 'text_marisa_trie.' + link_type.lower() + '.'
        if case_insensitive:
            fname_prefix = fname_prefix + 'case_insensitive.'

        count = 1
        cache_hits = 0
        cache_misses = 0
        total_link_count = 0
        specific_link_count = 0
        line = ''
        while count < 25000000 and line is not None:  # TODO check termination and remove magic number
            log_progress = (count < 50000 and count % 1000 == 0)
            if log_progress:
                self.logger.info('starting gc ')
                gc.collect()  # have no real reason to think this is needed or will help the memory issue
                self.logger.info(
                    '%d lines processed. links processed = %d, total_links = %d, percentage=%f, cache_hits = %d, '
                    'cache_misses = %d', count, specific_link_count,
                    total_link_count,
                    (specific_link_count / float(max(total_link_count, 1))),  # guard: total_link_count may still be 0
                    cache_hits, cache_misses)

            save_progress = count % 1000000 == 0 or count == 10
            if save_progress:
                self.logger.info('starting gc ')
                gc.collect()  # have no real reason to think this is needed or will help the memory issue
                self.logger.info(
                    "%d lines processed. links processed = %d, total_links = %d, percentage=%f, "
                    "cache_hits = %d, cache_misses = %d", count,
                    specific_link_count, total_link_count,
                    (specific_link_count / float(max(total_link_count, 1))),  # guard: total_link_count may still be 0
                    cache_hits, cache_misses)
                marisa_trie_filename = fname_prefix + str(count) + '.pickle'
                # t = marisa_trie.Trie(keys)
                # see http://marisa-trie.readthedocs.io/en/latest/tutorial.html
                fmt = "<Lb"
                # one unsigned 32-bit integer (the curid) followed by one signed byte per record
                #  see https://docs.python.org/3/library/struct.html#format-strings
                t2 = marisa_trie.RecordTrie(fmt, zip(unique_text, unique_wids))
                self.logger.info('about to save to %s', marisa_trie_filename)
                with open(marisa_trie_filename, 'wb') as handle:
                    pickle.dump(t2, handle, protocol=pickle.HIGHEST_PROTOCOL)
                self.logger.info('written  %s', marisa_trie_filename)

            line = input_file.readline()
            if line is not None and line != '':
                data = json.loads(line)
                # wikititle has underscores, 'title' has spaces
                # wid = data['wid']
                # title = data['title']
                # wikititle = data['wikiTitle']

                if 'links' in data:
                    links = data['links']
                    # pprint.pprint(links)
                    for link in links:
                        total_link_count += 1
                        if 'anchor' in link:
                            anchor_text = link['anchor']  # text
                            wikititle = link['id']  # text - matches wikititle
                            this_link_type = link['type']
                            if this_link_type == link_type:  # compare against the requested link_type parameter
                                specific_link_count += 1
                                if case_insensitive:
                                    wikititle = wikititle.lower()
                                    anchor_text = anchor_text.lower()

                                if wikititle in wikititle_trie:
                                    value_list = wikititle_trie[wikititle]
                                    curid = value_list[0][0]

                                    # if anchor_text not in unique_text_set:
                                    unique_text_set.add(anchor_text)
                                    unique_text.append(anchor_text)
                                    unique_wids.append((curid, 0))
                                    cache_hits += 1

                                else:
                                    # TODO change the below to a warning
                                    self.logger.debug(
                                        'wikititle %s not found in curid_by_wikititle_trie',
                                        wikititle)
                                    cache_misses += 1
            else:
                break

            count += 1

        self.logger.info('%d lines processed', count)
        marisa_trie_filename = fname_prefix + str(count) + '.pickle'
        # t = marisa_trie.Trie(keys)
        # see http://marisa-trie.readthedocs.io/en/latest/tutorial.html
        fmt = "<Lb"
        # one unsigned 32-bit integer (the curid) followed by one signed byte per record
        #  see https://docs.python.org/3/library/struct.html#format-strings
        t2 = marisa_trie.RecordTrie(fmt, zip(unique_text, unique_wids))
        self.logger.info('about to save to %s', marisa_trie_filename)
        with open(marisa_trie_filename, 'wb') as handle:
            pickle.dump(t2, handle, protocol=pickle.HIGHEST_PROTOCOL)
        self.logger.info('written  %s', marisa_trie_filename)