Code Example #1
File: conll_dataset.py Project: dmh43/entity-linking
 def __init__(self,
              cursor,
              entity_candidates_prior,
              embedding,
              token_idx_lookup,
              num_entities,
              num_candidates,
              entity_label_lookup,
              path='./AIDA-YAGO2-dataset.tsv'):
   self.cursor = cursor
   self.entity_candidates_prior = entity_candidates_prior
   self.embedding = embedding
   self.token_idx_lookup = token_idx_lookup
   self.num_entities = num_entities
   self.num_candidates = num_candidates
   # Read the AIDA-YAGO2 TSV, dropping the trailing line
   with open(path, 'r') as fh:
     self.lines = fh.read().strip().split('\n')[:-1]
   # Parse the documents out of the TSV lines and pre-embed each one
   self.documents = _get_documents(self.lines)
   self.embedded_documents = [embed_page_content(self.embedding, self.token_idx_lookup, document)
                              for document in self.documents]
   self.mentions = _get_mentions(self.lines)
   self.sentence_splits = _get_splits(self.documents, self.mentions)
   self.entity_page_ids = _get_entity_page_ids(self.lines)
   self.labels = _from_page_ids_to_entity_ids(cursor, self.entity_page_ids)
   # Keep only the mentions whose gold entity was resolved (-1 marks a missing label)
   self.with_label = [i for i, x in enumerate(self.labels) if x != -1]
   self.mention_doc_id = _get_doc_id_per_mention(self.lines)
   self.mentions_by_doc_id = _get_mentions_by_doc_id(self.lines)
   self.entity_label_lookup = entity_label_lookup
   # Invert entity_label_lookup (entity id -> label) into label -> entity id
   self.entity_id_lookup = {int(label): entity_id for entity_id, label in self.entity_label_lookup.items()}
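The label bookkeeping at the end of this constructor can be checked on its own with toy values (the real lookups come from the project's entity database): mentions whose gold entity could not be resolved carry a label of -1 and are filtered out, and entity_label_lookup (entity id -> dense label) is inverted so a dense label maps back to an entity id.

labels = [3, -1, 7]                       # entity id per mention, -1 = unresolved
with_label = [i for i, x in enumerate(labels) if x != -1]
assert with_label == [0, 2]

entity_label_lookup = {3: 0, 7: 1}        # entity id -> dense label index
entity_id_lookup = {int(label): entity_id
                    for entity_id, label in entity_label_lookup.items()}
assert entity_id_lookup == {0: 3, 1: 7}   # dense label -> entity id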
Code Example #2
 def _get_batch_embedded_page_content_lookup(self, page_ids):
     lookup = {}
     for page_id in page_ids:
         page_content = self._page_content_lookup[page_id]
         # Skip pages whose stripped content is essentially empty
         if len(page_content.strip()) > 5:
             lookup[page_id] = embed_page_content(self.embedding,
                                                  self.token_idx_lookup,
                                                  page_content)
     return lookup
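A minimal standalone sketch of the filtering rule used here, with a stub in place of embed_page_content so the snippet runs on its own: pages whose stripped content is five characters or shorter are left out of the returned lookup.

def embed_stub(content):                  # stand-in for embed_page_content
    return content.split()

page_content_lookup = {1: 'long enough content', 2: ' ab  ', 3: ''}
lookup = {}
for page_id, page_content in page_content_lookup.items():
    if len(page_content.strip()) > 5:     # same length cutoff as above
        lookup[page_id] = embed_stub(page_content)
assert list(lookup.keys()) == [1]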
Code Example #3
 def _get_batch_entity_page_mentions_lookup(self, page_ids):
     lookup = {}
     # Group mention infos by the page they occur on
     page_mention_infos_lookup = defaultdict(list)
     for mention_info in self._mention_infos.values():
         page_mention_infos_lookup[mention_info['page_id']].append(
             mention_info)
     for page_id in page_ids:
         page_mention_infos = page_mention_infos_lookup[page_id]
         content = ' '.join([
             mention_info['mention'] for mention_info in page_mention_infos
         ])
         # Pages with no mentions get an empty tensor instead of an embedding
         if _.is_empty(page_mention_infos):
             lookup[page_id] = torch.tensor([])
         else:
             lookup[page_id] = embed_page_content(self.embedding,
                                                  self.token_idx_lookup,
                                                  content,
                                                  page_mention_infos)
     return lookup
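The grouping step above can be illustrated with toy mention infos: they are bucketed by page_id, and the surface forms for each page are joined into the single string that gets embedded.

from collections import defaultdict

mention_infos = {
    10: {'page_id': 1, 'mention': 'Barack Obama'},
    11: {'page_id': 1, 'mention': 'United States'},
    12: {'page_id': 2, 'mention': 'Tokyo'},
}
page_mention_infos_lookup = defaultdict(list)
for mention_info in mention_infos.values():
    page_mention_infos_lookup[mention_info['page_id']].append(mention_info)

content = ' '.join(info['mention'] for info in page_mention_infos_lookup[1])
assert content == 'Barack Obama United States'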
Code Example #4
File: conll_dataset.py Project: dmh43/entity-linking
 def __getitem__(self, idx):
   # Map the requested index onto the subset of mentions that have a gold label
   idx = self.with_label[idx]
   # Dense label for the gold entity, falling back to -1 when the lookup has no entry
   label = self.entity_label_lookup.get(self.labels[idx]) or -1
   mention = self.mentions[idx]
   candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                     self.num_entities,
                                     self.num_candidates,
                                     mention,
                                     label)
   candidates = get_candidate_strs(self.cursor, [self.entity_id_lookup[cand_id] for cand_id in candidate_ids.tolist()])
   return {'sentence_splits': self.sentence_splits[idx],
           'label': label,
           'embedded_page_content': self.embedded_documents[self.mention_doc_id[idx]],
           'entity_page_mentions': embed_page_content(self.embedding,
                                                      self.token_idx_lookup,
                                                      ' '.join(self.mentions_by_doc_id[self.mention_doc_id[idx]])),
           'p_prior': get_p_prior(self.entity_candidates_prior, mention, candidate_ids),
           'candidate_ids': candidate_ids,
           'candidate_mention_sim': torch.tensor([Levenshtein.ratio(mention, candidate)
                                                  for candidate in candidates])}
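The candidate_mention_sim feature returned above is plain character-level string similarity between the mention and each candidate name. A standalone example, assuming the python-Levenshtein package (which provides Levenshtein.ratio) is installed; the candidate strings here are made up:

import torch
import Levenshtein

mention = 'Obama'
candidates = ['Barack Obama', 'Obama, Fukui', 'Michelle Obama']
candidate_mention_sim = torch.tensor([Levenshtein.ratio(mention, candidate)
                                      for candidate in candidates])
print(candidate_mention_sim)  # higher values mean a closer surface-form match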
Code Example #5
def test_embed_page_content():
    embedding_dict = _.map_values(
        {
            '<PAD>': [-1],
            '<UNK>': [0],
            'MENTION_START_HERE': [-2],
            'MENTION_END_HERE': [-3],
            'a': [1],
            'b': [2],
            'c': [3],
            'd': [4]
        }, torch.tensor)
    token_idx_lookup = dict(
        zip(embedding_dict.keys(), range(len(embedding_dict))))
    embedding = nn.Embedding.from_pretrained(
        torch.stack([embedding_dict[token] for token in token_idx_lookup]))
    # The mention 'b c' starts at character offset 2 of 'a b c d'
    page_mention_infos = [{'offset': 2, 'mention': 'b c'}]
    page_content = 'a b c d'
    # Expected result: the mention span is bracketed by the start/end marker vectors
    embedded = torch.tensor([[1], [-2], [2], [3], [-3], [4]])
    assert torch.equal(
        dt.embed_page_content(embedding, token_idx_lookup, page_content,
                              page_mention_infos), embedded)
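From this test alone one can infer roughly what embed_page_content does: wrap each mention span in MENTION_START_HERE / MENTION_END_HERE marker tokens, map every token to its index, and push the indices through the embedding. Below is a simplified sketch reconstructed only from the test; the real dt.embed_page_content in the project handles tokenization and cases this test does not exercise.

import torch

def embed_page_content_sketch(embedding, token_idx_lookup, page_content, page_mention_infos=()):
    # Insert marker tokens around each mention span, working right to left so
    # earlier character offsets stay valid.
    for info in sorted(page_mention_infos, key=lambda i: i['offset'], reverse=True):
        start, end = info['offset'], info['offset'] + len(info['mention'])
        page_content = (page_content[:start] + 'MENTION_START_HERE ' +
                        page_content[start:end] + ' MENTION_END_HERE' + page_content[end:])
    # Whitespace tokenization is a simplification of the project's tokenizer.
    idxs = [token_idx_lookup.get(token, token_idx_lookup['<UNK>'])
            for token in page_content.split()]
    return embedding(torch.tensor(idxs))

With the fixtures from the test, embed_page_content_sketch(embedding, token_idx_lookup, 'a b c d', [{'offset': 2, 'mention': 'b c'}]) reproduces the expected tensor [[1], [-2], [2], [3], [-3], [4]].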
Code Example #6
 def __init__(self,
              cursor,
              entity_candidates_prior,
              embedding,
              token_idx_lookup,
              num_entities,
              num_candidates,
              entity_label_lookup,
              path='./AIDA-YAGO2-dataset.tsv',
              use_wiki2vec=False,
              use_sum_encoder=False):
   self.cursor = cursor
   self.entity_candidates_prior = entity_candidates_prior
   self.embedding = embedding
   self.token_idx_lookup = token_idx_lookup
   self.num_entities = num_entities
   self.num_candidates = num_candidates
   with open(path, 'r') as fh:
     self.lines = fh.read().strip().split('\n')[:-1]
   self.documents = get_documents(self.lines)
   self.embedded_documents = [embed_page_content(self.embedding, self.token_idx_lookup, document)
                              for document in self.documents]
   self.mentions = get_mentions(self.lines)
   self.sentence_splits = get_splits(self.documents, self.mentions)
   self.mention_sentences = get_mention_sentences(self.documents, self.mentions)
   self.entity_page_ids = get_entity_page_ids(self.lines)
   self.labels = from_page_ids_to_entity_ids(cursor, self.entity_page_ids)
   self.with_label = [i for i, x in enumerate(self.labels) if x != -1]
   self.mention_doc_id = get_doc_id_per_mention(self.lines)
   self.mentions_by_doc_id = get_mentions_by_doc_id(self.lines)
   self.entity_label_lookup = entity_label_lookup
   self.entity_id_lookup = {int(label): entity_id for entity_id, label in self.entity_label_lookup.items()}
   self.use_wiki2vec = use_wiki2vec
   self.prior_approx_mapping = u.get_prior_approx_mapping(self.entity_candidates_prior)
   self.use_sum_encoder = use_sum_encoder
   self.stemmer = SnowballStemmer('english')
   # Per-document token counts: tokens are stemmed, mapped to indices, then counted
   self.page_token_cnts_lookup = [dict(Counter(u.to_idx(self.token_idx_lookup, self._stem(token))
                                               for token in parse_text_for_tokens(page_content)))
                                  for page_content in self.documents]
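The page_token_cnts_lookup built in the last statement is a per-document bag of stemmed token indices. A self-contained illustration, with simple stand-ins for parse_text_for_tokens and u.to_idx (NLTK is assumed for the Snowball stemmer):

from collections import Counter
from nltk.stem.snowball import SnowballStemmer

stemmer = SnowballStemmer('english')
token_idx_lookup = {'run': 0, 'dog': 1, '<UNK>': 2}

def to_idx(lookup, token):        # stand-in for u.to_idx
    return lookup.get(token, lookup['<UNK>'])

document = 'running dogs run'
tokens = document.split()         # stand-in for parse_text_for_tokens
page_token_cnts = dict(Counter(to_idx(token_idx_lookup, stemmer.stem(token))
                               for token in tokens))
assert page_token_cnts == {0: 2, 1: 1}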