def _get_candidate_ids(self, mention, label):
    """Delegate candidate-id lookup to the module-level helper.

    Forwards this dataset's prior, size limits, and the `cheat` flag so all
    call sites share one candidate-sampling implementation.
    """
    return get_candidate_ids(self.entity_candidates_prior,
                             self.num_entities,
                             self.num_candidates,
                             mention,
                             label,
                             cheat=self.cheat)
def test_get_candidate_ids():
    """The true label must appear among num_candidates distinct candidate ids."""
    prior = {'a': {1: 20}, 'b': {2: 12}, 'c': {3: 3}}
    num_entities = 300
    num_candidates = 300
    mention = 'b'
    label = 2
    candidate_ids = df.get_candidate_ids(prior, num_entities, num_candidates, mention, label)
    as_list = candidate_ids.tolist()
    # The gold entity is always included.
    assert label in as_list
    # Exactly the requested number of candidates, all unique.
    assert len(candidate_ids) == num_candidates
    assert len(set(as_list)) == num_candidates
def _getitem_sum_encoder(self, idx):
    """Assemble one sum-encoder example for the idx-th labeled mention.

    Maps the dataset index through `with_label`, samples candidate entities
    for the mention, and returns the features the sum-encoder model consumes.
    """
    row = self.with_label[idx]
    # -1 marks a label absent from the lookup (out-of-vocabulary entity).
    label = self.entity_label_lookup.get(self.labels[row], -1)
    mention = self.mentions[row]
    candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                      self.prior_approx_mapping,
                                      self.num_entities,
                                      self.num_candidates,
                                      mention,
                                      label)
    # Translate internal candidate indices to entity ids, then fetch strings.
    entity_ids = [self.entity_id_lookup[cid] for cid in candidate_ids.tolist()]
    candidates = get_candidate_strs(self.cursor, entity_ids)
    # String similarity between the mention and each candidate surface form.
    sims = [Levenshtein.ratio(mention, cand) for cand in candidates]
    return {'mention_sentence': self.mention_sentences[row],
            'page_token_cnts': self.page_token_cnts_lookup[self.mention_doc_id[row]],
            'label': label,
            'p_prior': get_p_prior(self.entity_candidates_prior, self.prior_approx_mapping, mention, candidate_ids),
            'candidate_ids': candidate_ids,
            'candidate_mention_sim': torch.tensor(sims)}
def _getitem_wiki2vec(self, idx):
    """Assemble one wiki2vec example for the idx-th labeled mention.

    Same candidate sampling as the sum-encoder path, but the document context
    is represented as a bag of nouns rather than a sentence.
    """
    row = self.with_label[idx]
    # -1 marks a label absent from the lookup (out-of-vocabulary entity).
    label = self.entity_label_lookup.get(self.labels[row], -1)
    mention = self.mentions[row]
    candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                      self.prior_approx_mapping,
                                      self.num_entities,
                                      self.num_candidates,
                                      mention,
                                      label)
    bag_of_nouns = get_bag_of_nouns(self.documents[self.mention_doc_id[row]])
    # Translate internal candidate indices to entity ids, then fetch strings.
    entity_ids = [self.entity_id_lookup[cid] for cid in candidate_ids.tolist()]
    candidates = get_candidate_strs(self.cursor, entity_ids)
    # String similarity between the mention and each candidate surface form.
    sims = [Levenshtein.ratio(mention, cand) for cand in candidates]
    return {'label': label,
            'bag_of_nouns': bag_of_nouns,
            'p_prior': get_p_prior(self.entity_candidates_prior, self.prior_approx_mapping, mention, candidate_ids),
            'candidate_ids': candidate_ids,
            'candidate_mention_sim': torch.tensor(sims)}
def __getitem__(self, idx):
    """Return one training example for the idx-th labeled mention.

    Maps the dataset index through `with_label`, samples candidate entity ids
    for the mention, and packages sentence/page features plus prior and
    string-similarity scores for each candidate.
    """
    idx = self.with_label[idx]
    # Bug fix: the previous `.get(self.labels[idx]) or -1` mapped a valid
    # label id of 0 to -1, because 0 is falsy. Use the dict default instead,
    # matching the sibling _getitem_* methods.
    label = self.entity_label_lookup.get(self.labels[idx], -1)
    mention = self.mentions[idx]
    candidate_ids = get_candidate_ids(self.entity_candidates_prior,
                                      self.num_entities,
                                      self.num_candidates,
                                      mention,
                                      label)
    # Translate internal candidate indices to entity ids, then fetch strings.
    candidates = get_candidate_strs(self.cursor,
                                    [self.entity_id_lookup[cand_id] for cand_id in candidate_ids.tolist()])
    return {'sentence_splits': self.sentence_splits[idx],
            'label': label,
            'embedded_page_content': self.embedded_documents[self.mention_doc_id[idx]],
            'entity_page_mentions': embed_page_content(self.embedding,
                                                       self.token_idx_lookup,
                                                       ' '.join(self.mentions_by_doc_id[self.mention_doc_id[idx]])),
            'p_prior': get_p_prior(self.entity_candidates_prior, mention, candidate_ids),
            'candidate_ids': candidate_ids,
            'candidate_mention_sim': torch.tensor([Levenshtein.ratio(mention, candidate) for candidate in candidates])}