# Example no. 1 (score: 0)
 def _wiki2vec_getitem(self, idx):
     """Build the wiki2vec training sample for the mention at `idx`.

     Pops the mention from the in-memory queue (fetching the next batch when
     the requested idx is not queued) and evicts the page's bag-of-nouns cache
     once every mention on that page has been served.
     """
     if self.use_fast_sampler:
         # Fast sampling ignores the requested idx and serves mentions
         # in queue order instead.
         if len(self._mention_infos) == 0: self._next_batch()
         idx = next(iter(self._mention_infos.keys()))
     if idx not in self._mention_infos:
         self._next_batch()
     mention_info = self._mention_infos.pop(idx)
     bag_of_nouns = self._bag_of_nouns_lookup[mention_info['page_id']]
     label = self.entity_label_lookup[mention_info['entity_id']]
     candidate_ids = self._get_candidate_ids(mention_info['mention'], label)
     # FIX: pass self.prior_approx_mapping — every other get_p_prior call
     # site in this class (_getitem, _get_batch_entity_page_mentions_lookup)
     # uses the 4-argument form; omitting it here called the wrong signature.
     p_prior = get_p_prior(self.entity_candidates_prior,
                           self.prior_approx_mapping,
                           mention_info['mention'], candidate_ids)
     candidates = self._get_candidate_strs(candidate_ids.tolist())
     sample = {
         'bag_of_nouns':
         bag_of_nouns,
         'label':
         label,
         'p_prior':
         p_prior,
         'candidate_ids':
         candidate_ids,
         'candidate_mention_sim':
         torch.tensor([
             Levenshtein.ratio(mention_info['mention'], candidate)
             for candidate in candidates
         ])
     }
     # Evict the page's cached data once its last mention is consumed.
     self._mentions_per_page_ctr[mention_info['page_id']] -= 1
     if self._mentions_per_page_ctr[mention_info['page_id']] == 0:
         self._bag_of_nouns_lookup.pop(mention_info['page_id'])
     return sample
 def _get_batch_entity_page_mentions_lookup(self, page_ids):
     """Map each page id to a tensor of the most probable entity id for
     every queued mention on that page (skipping mentions with an empty
     prior)."""
     mentions_by_page = defaultdict(list)
     for info in self._mention_infos.values():
         mentions_by_page[info['page_id']].append(info)
     device = self.entity_embeds.weight.device
     lookup = {}
     for page_id in page_ids:
         best_entity_ids = []
         for info in mentions_by_page[page_id]:
             mention = info['mention']
             if self.entity_candidates_prior.get(mention) is None:
                 # No exact-string prior: fall back to accent-stripped,
                 # lowercased approximate mentions and pool their candidates.
                 approx_mentions = self.prior_approx_mapping.get(
                     unidecode.unidecode(mention).lower(), [])
                 pooled = set()
                 for approx_mention in approx_mentions:
                     pooled.update(
                         self.entity_candidates_prior.get(approx_mention,
                                                          {}).keys())
                 candidate_ids = list(pooled)
             else:
                 candidate_ids = list(
                     self.entity_candidates_prior[mention].keys())
             prior = get_p_prior(self.entity_candidates_prior,
                                 self.prior_approx_mapping,
                                 info['mention'],
                                 torch.tensor(candidate_ids))
             if len(prior) > 0:
                 best_entity_ids.append(
                     candidate_ids[int(torch.argmax(prior))])
         lookup[page_id] = torch.tensor(best_entity_ids, device=device)
     return lookup
 def _getitem(self, idx):
     """Return the full training sample for the mention at `idx`.

     Pops the mention from the in-memory queue (pulling the next batch when
     needed) and drops all per-page caches once every mention on the page
     has been served.
     """
     if self.use_fast_sampler:
         # The fast sampler disregards `idx` and serves whatever is queued.
         if len(self._mention_infos) == 0: self._next_batch()
         idx = next(iter(self._mention_infos))
     if idx not in self._mention_infos:
         self._next_batch()
     info = self._mention_infos.pop(idx)
     page_id = info['page_id']
     mention = info['mention']
     spans = self._sentence_spans_lookup[page_id]
     content = self._page_content_lookup[page_id]
     label = self.entity_label_lookup[info['entity_id']]
     candidate_ids = self._get_candidate_ids(mention, label)
     p_prior = get_p_prior(self.entity_candidates_prior,
                           self.prior_approx_mapping,
                           mention, candidate_ids)
     candidate_strs = self._get_candidate_strs(candidate_ids.tolist())
     mention_sims = torch.tensor(
         [Levenshtein.ratio(mention, c) for c in candidate_strs])
     sample = {
         'sentence_splits': get_mention_sentence_splits(
             content, spans, info, lim=self.page_content_lim),
         'label': label,
         'embedded_page_content': self._embedded_page_content_lookup[page_id],
         'entity_page_mentions': self.entity_embeds(
             self._entity_page_mentions_lookup[page_id]),
         'p_prior': p_prior,
         'candidate_ids': candidate_ids,
         'candidate_mention_sim': mention_sims,
     }
     # Last mention on this page -> evict every per-page cache entry.
     self._mentions_per_page_ctr[page_id] -= 1
     if self._mentions_per_page_ctr[page_id] == 0:
         for cache in (self._sentence_spans_lookup,
                       self._page_content_lookup,
                       self._embedded_page_content_lookup,
                       self._entity_page_mentions_lookup):
             cache.pop(page_id)
     return sample
# Example no. 4 (score: 0)
 def _getitem_sum_encoder(self, idx):
   """Assemble the sum-encoder sample for the `idx`-th labeled mention."""
   idx = self.with_label[idx]
   mention = self.mentions[idx]
   # Mentions whose entity is unknown to the label lookup get label -1.
   label = self.entity_label_lookup.get(self.labels[idx], -1)
   candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                     self.prior_approx_mapping,
                                     self.num_entities,
                                     self.num_candidates,
                                     mention,
                                     label)
   entity_ids = [self.entity_id_lookup[cid] for cid in candidate_ids.tolist()]
   candidates = get_candidate_strs(self.cursor, entity_ids)
   p_prior = get_p_prior(self.entity_candidates_prior,
                         self.prior_approx_mapping,
                         mention,
                         candidate_ids)
   sims = torch.tensor([Levenshtein.ratio(mention, candidate)
                        for candidate in candidates])
   return {'mention_sentence': self.mention_sentences[idx],
           'page_token_cnts': self.page_token_cnts_lookup[self.mention_doc_id[idx]],
           'label': label,
           'p_prior': p_prior,
           'candidate_ids': candidate_ids,
           'candidate_mention_sim': sims}
# Example no. 5 (score: 0)
 def _getitem_wiki2vec(self, idx):
   """Assemble the wiki2vec sample for the `idx`-th labeled mention."""
   idx = self.with_label[idx]
   mention = self.mentions[idx]
   # Unknown entities fall back to label -1.
   label = self.entity_label_lookup.get(self.labels[idx], -1)
   candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                     self.prior_approx_mapping,
                                     self.num_entities,
                                     self.num_candidates,
                                     mention,
                                     label)
   bag_of_nouns = get_bag_of_nouns(self.documents[self.mention_doc_id[idx]])
   entity_ids = [self.entity_id_lookup[cid] for cid in candidate_ids.tolist()]
   candidates = get_candidate_strs(self.cursor, entity_ids)
   p_prior = get_p_prior(self.entity_candidates_prior,
                         self.prior_approx_mapping,
                         mention,
                         candidate_ids)
   sims = torch.tensor([Levenshtein.ratio(mention, candidate)
                        for candidate in candidates])
   return {'label': label,
           'bag_of_nouns': bag_of_nouns,
           'p_prior': p_prior,
           'candidate_ids': candidate_ids,
           'candidate_mention_sim': sims}
# Example no. 6 (score: 0)
 def __getitem__(self, idx):
   """Return the training sample for the `idx`-th labeled mention."""
   idx = self.with_label[idx]
   # FIX: use the `.get(key, default)` form. The previous
   # `.get(...) or -1` collapsed a legitimate label of 0 (falsy) to -1;
   # the sibling getters use `.get(self.labels[idx], -1)`.
   label = self.entity_label_lookup.get(self.labels[idx], -1)
   mention = self.mentions[idx]
   candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                     self.num_entities,
                                     self.num_candidates,
                                     mention,
                                     label)
   candidates = get_candidate_strs(self.cursor, [self.entity_id_lookup[cand_id] for cand_id in candidate_ids.tolist()])
   return {'sentence_splits': self.sentence_splits[idx],
           'label': label,
           'embedded_page_content': self.embedded_documents[self.mention_doc_id[idx]],
           'entity_page_mentions': embed_page_content(self.embedding,
                                                      self.token_idx_lookup,
                                                      ' '.join(self.mentions_by_doc_id[self.mention_doc_id[idx]])),
           'p_prior': get_p_prior(self.entity_candidates_prior, mention, candidate_ids),
           'candidate_ids': candidate_ids,
           'candidate_mention_sim': torch.tensor([Levenshtein.ratio(mention, candidate)
                                                  for candidate in candidates])}